├── cache ├── valid_years.pt ├── logit_indices.pt ├── potential_nouns.txt └── gelu_12_tied.circ ├── requirements.txt ├── topk_years.py ├── random_years.py ├── color_utils.py ├── README.md ├── pca_plots.py ├── dataset.py ├── random_circuit_ablation.py ├── neuron_plots.py ├── neuron_investigations.py ├── big_ds_experiments.py ├── appendix_plots.py ├── circuit_discovery_plotting.py ├── sequence_generalization.py ├── circuit_discovery.py └── utils.py /cache/valid_years.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hannamw/gpt2-greater-than/HEAD/cache/valid_years.pt -------------------------------------------------------------------------------- /cache/logit_indices.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hannamw/gpt2-greater-than/HEAD/cache/logit_indices.pt -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | attrs 2 | einops 3 | jupyter 4 | ipywidgets 5 | ipykernel 6 | scikit-learn 7 | requests 8 | pandas 9 | fancy_einsum 10 | seaborn 11 | tqdm 12 | websockets 13 | get-mnist 14 | transformers 15 | tabulate 16 | plotly 17 | black 18 | matplotlib 19 | imageio 20 | pillow -------------------------------------------------------------------------------- /cache/potential_nouns.txt: -------------------------------------------------------------------------------- 1 | abduction 2 | accord 3 | affair 4 | agreement 5 | appraisal 6 | assaults 7 | assessment 8 | attack 9 | attempts 10 | campaign 11 | captivity 12 | case 13 | challenge 14 | chaos 15 | clash 16 | collaboration 17 | coma 18 | competition 19 | confrontation 20 | consequence 21 | conspiracy 22 | construction 23 | consultation 24 | contact 25 | contract 26 | convention 27 | cooperation 28 | custody 29 | deal 30 | decline 31 | decrease 32 | demonstrations 33 | development 34 | disagreement 35 | disorder 36 | dispute 37 | domination 38 | dynasty 39 | effect 40 | effort 41 | employment 42 | endeavor 43 | engagement 44 | epidemic 45 | evaluation 46 | exchange 47 | existence 48 | expansion 49 | expedition 50 | experiments 51 | fall 52 | fame 53 | flights 54 | friendship 55 | growth 56 | hardship 57 | hostility 58 | illness 59 | impact 60 | imprisonment 61 | improvement 62 | incarceration 63 | increase 64 | insurgency 65 | invasion 66 | investigation 67 | journey 68 | kingdom 69 | marriage 70 | modernization 71 | negotiation 72 | notoriety 73 | obstruction 74 | operation 75 | order 76 | outbreak 77 | outcome 78 | overhaul 79 | patrols 80 | pilgrimage 81 | plague 82 | plan 83 | practice 84 | process 85 | program 86 | progress 87 | project 88 | pursuit 89 | quest 90 | raids 91 | reforms 92 | reign 93 | relationship 94 | retaliation 95 | riot 96 | rise 97 | rivalry 98 | romance 99 | rule 100 | sanctions 101 | shift 102 | siege 103 | slump 104 | stature 105 | stint 106 | strikes 107 | study 108 | test 109 | testing 110 | tests 111 | therapy 112 | tour 113 | tradition 114 | treaty 115 | trial 116 | trip 117 | unemployment 118 | voyage 119 | warfare 120 | work 121 | -------------------------------------------------------------------------------- /topk_years.py: -------------------------------------------------------------------------------- 1 | #%% 2 | from pathlib import Path 3 | 4 | import matplotlib.pyplot as plt 5 | from transformers import AutoTokenizer, AutoModelForCausalLM 6 | import torch 7 | 8 | from dataset import YearDataset 9 | from utils import get_valid_years 10 | 11 | DEVICE='cuda:0' 12 | #%% 13 | model = AutoModelForCausalLM.from_pretrained('gpt2') 14 | model.to(DEVICE) 15 | tokenizer = AutoTokenizer.from_pretrained('gpt2') 16 | years_to_sample_from = get_valid_years(tokenizer, 1000, 1900) 17 | N = 400 18 | ds = YearDataset(years_to_sample_from, N, Path("cache/potential_nouns.txt"), tokenizer, balanced=False, device=DEVICE, eos=True) 19 | 20 | # %% 21 | with torch.inference_mode(): 22 | logits = model(ds.good_toks).logits 23 | # %% 24 | year_probs = torch.softmax(logits[:, -2], dim=-1) 25 | topk = torch.topk(year_probs, k=100) 26 | topk_tokens = [[tokenizer._convert_id_to_token(top) for top in ex] for ex in topk.indices] 27 | topk_numbers = torch.tensor([[int(tok[1:]) if tok[1:].isnumeric() else 0 for tok in ex] for ex in topk_tokens], device=DEVICE) 28 | zeros = torch.zeros_like(topk.values, device=DEVICE) 29 | # %% 30 | valid_prob = torch.where(topk_numbers >= ds.years_XX.view(-1,1).to(DEVICE), topk.values, zeros) 31 | print(valid_prob.sum(-1).mean(), topk.values.sum(-1).mean()) 32 | 33 | # %% 34 | year_prob = torch.where(topk_numbers == ds.years_XX.view(-1,1).to(DEVICE), topk.values, zeros) 35 | print(year_prob.sum(-1).mean(), topk.values.sum(-1).mean()) 36 | # %% 37 | 38 | year_probs = torch.softmax(logits[:, -1], dim=-1) 39 | topk = torch.topk(year_probs, k=5) 40 | topk_tokens = [[tokenizer._convert_id_to_token(top) for top in ex] for ex in topk.indices] 41 | topk_numbers = torch.tensor([[int(tok) if tok.isnumeric() else 0 for tok in ex] for ex in topk_tokens], device=DEVICE) 42 | zeros = torch.zeros_like(topk.values, device=DEVICE) 43 | valid_prob = torch.where(topk_numbers >= ds.years_YY.view(-1,1).to(DEVICE), topk.values, zeros) 44 | print(valid_prob.sum(-1).mean(), topk.values.sum(-1).mean()) 45 | # %% 46 | print((topk_numbers >= ds.years_YY.view(-1,1).to(DEVICE)).float().sum(-1).mean()) -------------------------------------------------------------------------------- /random_years.py: -------------------------------------------------------------------------------- 1 | #%% 2 | from pathlib import Path 3 | 4 | import matplotlib.pyplot as plt 5 | from transformers import AutoTokenizer, AutoModelForCausalLM 6 | import torch 7 | 8 | from dataset import YearDataset 9 | from utils import year_indices, get_valid_years 10 | 11 | DEVICE='cuda:0' 12 | #%% 13 | model = AutoModelForCausalLM.from_pretrained('gpt2') 14 | model.to(DEVICE) 15 | tokenizer = AutoTokenizer.from_pretrained('gpt2') 16 | years_to_sample_from = get_valid_years(tokenizer, 1000, 1900) 17 | N = 400 18 | ds = YearDataset(years_to_sample_from, N, Path("cache/potential_nouns.txt"), tokenizer, balanced=False, device=DEVICE, eos=True) 19 | 20 | # %% 21 | N_prior = 200 22 | start_years = years_to_sample_from[torch.randint(years_to_sample_from.size(0), (N_prior,))] 23 | years_XX = start_years // 100 24 | years_XX00 = (start_years // 100) * 100 25 | years = [] 26 | for XX00 in years_XX00: 27 | sample_space = years_to_sample_from[(years_to_sample_from >= XX00) & (years_to_sample_from < XX00+100)] 28 | years.append(sample_space[torch.randint(sample_space.size(0), (5,))]) 29 | years = torch.stack(years) 30 | years_YY = years % 100 31 | 32 | years_strings = [f'{str(y.tolist())[1:-1]}, {XX}' for y, XX in zip(years, years_XX)] 33 | years_tokens = tokenizer(years_strings, return_tensors="pt")['input_ids'].to(DEVICE) 34 | 35 | with torch.inference_mode(): 36 | logits = model(years_tokens).logits[:, -1] 37 | probs = torch.softmax(logits, dim=-1) 38 | year_probs = probs[:, year_indices] 39 | topk = torch.topk(probs, k=5) 40 | topk_tokens = [[tokenizer._convert_id_to_token(top) for top in ex] for ex in topk.indices] 41 | 42 | def comp_prob(probs, years_YY, gt=True): 43 | comps = [] 44 | for prob, year in zip(probs, years_YY[:, -1]): 45 | if gt: 46 | comps.append(prob[year+1:].sum()) 47 | else: 48 | comps.append(prob[:year+1].sum()) 49 | return torch.stack(comps) 50 | print(year_probs[torch.arange(year_probs.size(0)), years_YY[:, -1]+1].mean()) 51 | print(comp_prob(year_probs, years_YY).mean()) 52 | print(comp_prob(year_probs, years_YY, False).mean()) 53 | # %% 54 | i = 4 55 | plt.plot(year_probs[i].cpu()) 56 | plt.title(f"GPT-2 Probabilities when YY={years_YY[i].tolist()}") 57 | plt.xlabel(f"Predicted Year") 58 | plt.ylabel(f"probability") 59 | plt.show() 60 | # %% 61 | -------------------------------------------------------------------------------- /color_utils.py: -------------------------------------------------------------------------------- 1 | # These color utils (for making pretty PCA plots) are adapted from https://bsouthga.dev/posts/color-gradients-with-python 2 | 3 | def hex_to_RGB(hex): 4 | """ "#FFFFFF" -> [255,255,255]""" 5 | # Pass 16 to the integer function for change of base 6 | return [int(hex[i : i + 2], 16) for i in range(1, 6, 2)] 7 | 8 | 9 | def RGB_to_hex(RGB): 10 | '''[255,255,255] -> "#FFFFFF"''' 11 | # Components need to be integers for hex to make sense 12 | RGB = [int(x) for x in RGB] 13 | return "#" + "".join(["0{0:x}".format(v) if v < 16 else "{0:x}".format(v) for v in RGB]) 14 | 15 | 16 | def color_dict(gradient): 17 | """Takes in a list of RGB sub-lists and returns dictionary of 18 | colors in RGB and hex form for use in a graphing function 19 | defined later on""" 20 | return { 21 | "hex": [RGB_to_hex(RGB) for RGB in gradient], 22 | "r": [RGB[0] for RGB in gradient], 23 | "g": [RGB[1] for RGB in gradient], 24 | "b": [RGB[2] for RGB in gradient], 25 | } 26 | 27 | 28 | def linear_gradient(start_hex, finish_hex="#FFFFFF", n=10): 29 | """returns a gradient list of (n) colors between 30 | two hex colors. start_hex and finish_hex 31 | should be the full six-digit color string, 32 | inlcuding the number sign ("#FFFFFF")""" 33 | # Starting and ending colors in RGB form 34 | s = hex_to_RGB(start_hex) 35 | f = hex_to_RGB(finish_hex) 36 | # Initilize a list of the output colors with the starting color 37 | RGB_list = [s] 38 | # Calcuate a color at each evenly spaced value of t from 1 to n 39 | for t in range(1, n): 40 | # Interpolate RGB vector for color at the current value of t 41 | curr_vector = [int(s[j] + (float(t) / (n - 1)) * (f[j] - s[j])) for j in range(3)] 42 | # Add it to our list of output colors 43 | RGB_list.append(curr_vector) 44 | 45 | return color_dict(RGB_list) 46 | 47 | 48 | def polylinear_gradient(colors, n): 49 | """returns a list of colors forming linear gradients between 50 | all sequential pairs of colors. "n" specifies the total 51 | number of desired output colors""" 52 | # The number of colors per individual linear gradient 53 | n_out = int(float(n) / (len(colors) - 1)) 54 | # returns dictionary defined by color_dict() 55 | gradient_dict = linear_gradient(colors[0], colors[1], n_out) 56 | 57 | if len(colors) > 1: 58 | for col in range(1, len(colors) - 1): 59 | next = linear_gradient(colors[col], colors[col + 1], n_out) 60 | for k in ("hex", "r", "g", "b"): 61 | # Exclude first point to avoid duplicates 62 | gradient_dict[k] += next[k][1:] 63 | 64 | return gradient_dict 65 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Code Release 2 | Here is a code release for the 2023 NeurIPS paper "How does GPT-2 compute greater-than?: Interpreting mathematical abilities in a pre-trained language model". 3 | 4 | The code release is structured as follows: 5 | - `circuit_discovery.py` reproduces the circuit discovery and semantics assignment process 6 | - `big_ds_experiments.py` reproduces the experiments run on the larger, 10,000 element dataset 7 | - `neuron_investigations.py` reproduces the neuron-level experiments 8 | - `sequence_generalization.py` reproduces the generalization experiments 9 | 10 | The aforementioned files will cache files (in `paper-cache`) that can be used to generate plots using these scripts: 11 | - `circuit_discovery_plotting.py` 12 | - `pca_plots.py` 13 | - `neuron_plots.py` 14 | - `appendix_plots.py` 15 | 16 | In addition, we include three useful files in the `cache` folder (indices of the relevant logits, the nouns used in our template, and the order of MLP10 neurons, which otherwise takes a long time to compute). Finally, we include two utility files, `utils.py` and `color_utils.py` (for plotting). 17 | 18 | Most of these experiments started as exploratory VSCode notebooks, but can be run just as easily as Python scripts, and will produce all necessary output. We also include a few smaller notebooks that correspond to discussions with reviewers, and didn't fit in neatly with the rest of our experiments: 19 | - `random_circuit_ablation.py`: allows you to try ablating random circuits, as opposed to the one we found 20 | - `random_years.py`: tests GPT-2's responses to random sequences of years from the same century 21 | - `topk_years.py`: tests the degree to which GPT-2's top-k YY predictions are correct. 22 | 23 | ## Running the code 24 | Unfortunately, using the `rust-circuit` library to work with `gpt2-small` is not easy. To run the code, follow these steps: 25 | 26 | 1. Compile [rust-circuit](https://github.com/redwoodresearch/rust_circuit_public), following the instructions there given; note that this requires clang and rust. The repo instructs you to install maturin; be sure to install 0.14.x (we used 0.14.7), as newer versions do not work. 27 | 2. Install this project's requirements via the provided requirements file `pip install -r requirements.txt` 28 | 3. Download the `gpt2-small` model files from [this link](https://rrserve.s3.us-west-2.amazonaws.com/remix/remix_tensors.zip). Extract them to a folder called `../rrfs/tensor_db`. 29 | 30 | # The paper 31 | Our paper is available on [ArXiv](https://arxiv.org/abs/2305.00586) and on OpenReview / the NeurIPS website. You can cite it like this: 32 | 33 | ``` 34 | @inproceedings{ 35 | hanna2023how, 36 | title={How does {GPT}-2 compute greater-than?: Interpreting mathematical abilities in a pre-trained language model}, 37 | author={Michael Hanna and Ollie Liu and Alexandre Variengien}, 38 | booktitle={Thirty-seventh Conference on Neural Information Processing Systems}, 39 | year={2023}, 40 | url={https://openreview.net/forum?id=p4PckNQR8k} 41 | } 42 | ``` 43 | -------------------------------------------------------------------------------- /pca_plots.py: -------------------------------------------------------------------------------- 1 | #%% 2 | from pathlib import Path 3 | 4 | import torch 5 | import rust_circuit as rc 6 | import matplotlib.pyplot as plt 7 | 8 | from dataset import YearDataset 9 | from utils import ( 10 | load_gpt2_small_circuit, 11 | load_diff_model, 12 | to_device, 13 | load_and_split_gpt2, 14 | ) 15 | 16 | from color_utils import RGB_to_hex, polylinear_gradient 17 | from sklearn.decomposition import PCA 18 | import numpy as np 19 | 20 | #%% 21 | # Loading our base model 22 | DEVICE = "cuda" 23 | _, tokenizer, _ = load_gpt2_small_circuit() 24 | 25 | #%% 26 | # Creating our dataset 27 | years_to_sample_from = torch.arange(1702, 1799) 28 | N = len(years_to_sample_from) 29 | ds = YearDataset(years_to_sample_from, N, Path("cache/potential_nouns.txt"), ordered=True, device=DEVICE) 30 | 31 | MAX_LEN = ds.good_toks.size(-1) 32 | END_POS = MAX_LEN - 1 33 | XX1_POS = ds.good_prompt.index("XX1") 34 | YY_POS = ds.good_prompt.index("YY") 35 | last_two_digits = ds.years_YY 36 | 37 | #%% 38 | # Splitting our model to make it pretty 39 | metric = "prob" 40 | circuit = load_and_split_gpt2(MAX_LEN) 41 | year_indices = torch.load("cache/logit_indices.pt") 42 | ld_circuit, group = load_diff_model(circuit, year_indices, ds.good_mask, logit_diff=False, device=DEVICE) 43 | 44 | #%% 45 | def se(c): 46 | """Short function for Sample and Evaluate along the global variable `group`""" 47 | transform = rc.Sampler(rc.RunDiscreteVarAllSpec([group])) 48 | return transform.sample(c).evaluate() 49 | 50 | 51 | # Let's make a copy of the circuit that actually has inputs! 52 | c = ld_circuit.update( 53 | "tokens", 54 | lambda _: rc.DiscreteVar(to_device(rc.Array(ds.good_toks, name="tokens"), DEVICE), probs_and_group=group), 55 | ) 56 | 57 | #%% 58 | BASE_PALETTE = [ 59 | [255, 128, 0], 60 | [128, 255, 0], 61 | [128, 0, 255], 62 | [0, 128, 255], 63 | [0, 255, 128], 64 | [255, 0, 255], 65 | [255, 255, 0], 66 | [0, 255, 255], 67 | [0, 76, 153], 68 | [76, 0, 153], 69 | ] 70 | color_gradient = polylinear_gradient([RGB_to_hex(color) for color in BASE_PALETTE], 110)["hex"][2:99] 71 | print(len(color_gradient)) 72 | 73 | 74 | def plot(reps, title="PCA for Embeddings"): 75 | reducer = PCA(n_components=2) 76 | fits = np.array(reducer.fit_transform(reps)) 77 | print(reducer.explained_variance_ratio_) 78 | return raw_plot(fits, title) 79 | 80 | 81 | def raw_plot(fits, title="PCA for Embeddings"): 82 | fig, ax = plt.subplots() 83 | ax.scatter(fits[:, 0], fits[:, 1], c=color_gradient) 84 | 85 | texts = [str(i) for i in range(2, 99)] 86 | for i, txt in enumerate(texts): 87 | ax.annotate(txt, (fits[i, 0], fits[i, 1]), xytext=(2, 2), textcoords="offset points") 88 | 89 | ax.set_title(title) 90 | return fig 91 | # %% 92 | a7h10_reps = se(c.get_unique("a7_h10_t11")).squeeze().cpu().numpy() 93 | fig = plot(a7h10_reps, title="PCA of a7.h10 outputs") 94 | # %% 95 | a7h8_reps = se(c.get_unique("a7_h8_t11")).squeeze().cpu().numpy() 96 | fig = plot(a7h8_reps, title="PCA of a7.h8 outputs") 97 | #%% 98 | m8_input = se(c.get_unique("b8.m.input"))[:, 11].squeeze().cpu().numpy() 99 | fig = plot(m8_input, title="PCA of MLP 8 input") 100 | # %% 101 | embeds = se(c.get_unique("embeds"))[:, 7].squeeze().cpu().numpy() 102 | fig = plot(embeds, title="PCA of static embeddings") 103 | #%% 104 | fig, axs = plt.subplots(1, 4, sharey=True) 105 | remove_ticks = False 106 | for reps, ax, name in zip( 107 | [m8_input, a7h10_reps, a7h8_reps, embeds], 108 | axs, 109 | ["MLP 8 Input", "a7.h10 Output", "a7.h8 Output", "Static Embeddings"], 110 | ): 111 | reducer = PCA(n_components=2) 112 | fits = np.array(reducer.fit_transform(reps)) 113 | fits[:, 0] = (fits[:, 0] - fits[:, 0].mean()) / (fits[:, 0].std()) 114 | fits[:, 1] = (fits[:, 1] - fits[:, 1].mean()) / (fits[:, 1].std()) 115 | ax.scatter(fits[:, 0], fits[:, 1], c=color_gradient) 116 | 117 | texts = [f"{i:02d}" for i in range(2, 99)] 118 | for i, txt in enumerate(texts): 119 | ax.annotate(txt, (fits[i, 0], fits[i, 1]), xytext=(2, 2), textcoords="offset points") 120 | ax.set_title(f"PCA of {name}") 121 | # ax.set_ylim(-2,2) 122 | if remove_ticks: 123 | ax.tick_params( 124 | axis="y", 125 | left=False, 126 | ) 127 | else: 128 | remove_ticks = True 129 | fig.set_size_inches([20, 6]) 130 | fig.set_dpi(200) 131 | fig.tight_layout() 132 | fig.savefig("paper-plots/pca-analysis.pdf") 133 | fig.show() 134 | 135 | -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | import random 2 | from typing import List, Union 3 | from pathlib import Path 4 | 5 | import torch 6 | from transformers import PreTrainedTokenizer 7 | 8 | 9 | def generate_real_sentence(noun: str, year: int, eos: bool = False) -> str: 10 | century = year // 100 11 | sentence = f"The {noun} lasted from the year {year} to the year {century}" 12 | if eos: 13 | sentence = "<|endoftext|> " + sentence 14 | return sentence 15 | 16 | 17 | def real_sentence_prompt(eos: bool = False) -> List[str]: 18 | sentence = f"The NOUN lasted from the year XX1 YY to the year XX2".split() 19 | if eos: 20 | sentence = ["<|endoftext|>"] + sentence 21 | return sentence 22 | 23 | 24 | def generate_bad_sentence(noun: str, year: int, eos: bool = False) -> str: 25 | century = year // 100 26 | sentence = f"The {noun} lasted from the year {century}01 to the year {century}" 27 | if eos: 28 | sentence = "<|endoftext|> " + sentence 29 | return sentence 30 | 31 | 32 | def bad_sentence_prompt(eos: bool = False) -> List[str]: 33 | sentence = f"The NOUN lasted from the year XX1 01 to the year XX2".split() 34 | if eos: 35 | sentence = ["<|endoftext|>"] + sentence 36 | return sentence 37 | 38 | 39 | def is_valid_year(year: str, tokenizer) -> bool: 40 | _year = " " + year 41 | token = tokenizer(_year)["input_ids"] 42 | detok = tokenizer.convert_ids_to_tokens(token) 43 | return len(detok) == 2 and len(detok[1]) == 2 44 | 45 | 46 | class YearDataset: 47 | years_to_sample_from: torch.Tensor 48 | N: int 49 | ordered: bool 50 | eos: bool 51 | 52 | nouns: List[str] 53 | years: torch.Tensor 54 | years_YY: torch.Tensor 55 | good_sentences: List[str] 56 | bad_sentences: List[str] 57 | good_toks: torch.Tensor 58 | bad_toks: torch.Tensor 59 | good_prompt: List[str] 60 | bad_prompt: List[str] 61 | good_mask: torch.Tensor 62 | tokenizer: PreTrainedTokenizer 63 | 64 | def __init__( 65 | self, 66 | years_to_sample_from, 67 | N: int, 68 | nouns: Union[str, List[str], Path], 69 | tokenizer: PreTrainedTokenizer, 70 | balanced: bool = True, 71 | eos: bool = False, 72 | device: str = "cpu", 73 | ): 74 | self.years_to_sample_from = years_to_sample_from 75 | self.N = N 76 | self.eos=eos 77 | 78 | if isinstance(nouns, str): 79 | noun_list = [nouns] 80 | elif isinstance(nouns, list): 81 | noun_list = nouns 82 | elif isinstance(nouns, Path): 83 | with open(nouns, "r") as f: 84 | noun_list = [line.strip() for line in f] 85 | else: 86 | raise ValueError(f"Got bad type of nouns: {type(nouns)}; for nouns: {nouns}") 87 | 88 | self.nouns = random.choices(noun_list, k=N) 89 | 90 | if balanced: 91 | years = [] 92 | current_year = 2 93 | years_to_sample_from_YY = self.years_to_sample_from % 100 94 | for i in range(N): 95 | sample_pool = self.years_to_sample_from[years_to_sample_from_YY == current_year] 96 | years.append(sample_pool[random.randrange(len(sample_pool))]) 97 | current_year += 1 98 | if current_year >= 99: 99 | current_year -= 97 100 | self.years = torch.tensor(years) 101 | else: 102 | self.years = torch.tensor(self.years_to_sample_from[torch.randint(0, len(self.years_to_sample_from), (N,))]) 103 | 104 | self.years_XX = self.years // 100 105 | self.years_YY = self.years % 100 106 | 107 | self.good_sentences = [ 108 | generate_real_sentence(noun, int(year.item()), eos=eos) for noun, year in zip(self.nouns, self.years) 109 | ] 110 | self.bad_sentences = [ 111 | generate_bad_sentence(noun, int(year.item()), eos=eos) for noun, year in zip(self.nouns, self.years) 112 | ] 113 | 114 | self.good_prompt = real_sentence_prompt(eos=eos) 115 | self.bad_prompt = bad_sentence_prompt(eos=eos) 116 | 117 | good_tokenized = tokenizer(self.good_sentences, return_tensors="pt") 118 | self.good_toks, good_attn = good_tokenized["input_ids"], good_tokenized["attention_mask"] 119 | assert torch.all(good_attn == 1) 120 | bad_tokenized = tokenizer(self.bad_sentences, return_tensors="pt") 121 | self.bad_toks, bad_attn = bad_tokenized["input_ids"], bad_tokenized["attention_mask"] 122 | assert torch.all(bad_attn == 1) 123 | 124 | # there's a better way to do this 125 | _good_logits_masks = [] 126 | for year in self.years_YY: 127 | logits_mask = torch.arange(100) 128 | _good_logits_masks.append(logits_mask > year) 129 | self.good_mask = torch.stack(_good_logits_masks) 130 | 131 | self.good_toks = self.good_toks.to(device) 132 | self.bad_toks = self.bad_toks.to(device) 133 | self.good_mask = self.good_mask.to(device) 134 | 135 | def __len__(self): 136 | return self.N 137 | 138 | -------------------------------------------------------------------------------- /random_circuit_ablation.py: -------------------------------------------------------------------------------- 1 | #%% 2 | from typing import cast 3 | from pathlib import Path 4 | 5 | import torch 6 | import rust_circuit as rc 7 | from rust_circuit.causal_scrubbing.hypothesis import corr_root_matcher 8 | 9 | from dataset import YearDataset 10 | from utils import ( 11 | collate, 12 | HeadOrMlpType, 13 | MLPHeadAndPosSpec, 14 | load_gpt2_small_circuit, 15 | load_diff_model, 16 | iterative_path_patching_nocorr, 17 | path_patching, 18 | to_device, 19 | load_and_split_gpt2, 20 | show_diffs, 21 | get_valid_years, 22 | make_extender_factory, 23 | get_valid_years, 24 | ) 25 | 26 | #%% 27 | # Loading our base model 28 | DEVICE = "cuda:0" 29 | MODEL_ID = "gelu_12_tied" # aka gpt2 small 30 | _, tokenizer, _ = load_gpt2_small_circuit() 31 | 32 | #%% 33 | # Creating our dataset 34 | years_to_sample_from = get_valid_years(tokenizer, 1000, 1900) 35 | N = 200 36 | ds = YearDataset(years_to_sample_from, N, Path("cache/potential_nouns.txt"), balanced=True, device=DEVICE, eos=True) 37 | 38 | MAX_LEN = ds.good_toks.size(-1) 39 | END_POS = MAX_LEN - 1 40 | XX1_POS = ds.good_prompt.index("XX1") 41 | YY_POS = ds.good_prompt.index("YY") 42 | last_two_digits = ds.years_YY 43 | 44 | #%% 45 | # Splitting our model to make it pretty 46 | circuit = load_and_split_gpt2(MAX_LEN) 47 | year_indices = torch.load("cache/logit_indices.pt") 48 | ld_circuit, group = load_diff_model(circuit, year_indices, ds.good_mask, device=DEVICE) 49 | 50 | #%% 51 | def se(c): 52 | """Short function for Sample and Evaluate along the global variable `group`""" 53 | transform = rc.Sampler(rc.RunDiscreteVarAllSpec([group])) 54 | return transform.sample(c).evaluate() 55 | 56 | 57 | def sec(c): 58 | """Short function for Sample and Evaluate along the global variable `group`""" 59 | transform = rc.Sampler(rc.RunDiscreteVarAllSpec([group])) 60 | return collate(transform.sample(c).evaluate(), ds.years_YY) 61 | 62 | 63 | # Let's make a copy of the circuit that actually has inputs! 64 | c = ld_circuit.update( 65 | "tokens", 66 | lambda _: rc.DiscreteVar(to_device(rc.Array(ds.good_toks, name="tokens"), DEVICE), probs_and_group=group), 67 | ) 68 | baseline_mean = se(c).mean() 69 | 70 | 71 | #%% 72 | # We need to make an extender factory, and then some matcher extenders to iteratively path patch with 73 | extender_factory = make_extender_factory(MAX_LEN) 74 | end_pos_matcher_extenders = [ 75 | extender_factory(MLPHeadAndPosSpec(l, cast(HeadOrMlpType, h), END_POS), qkv=None) 76 | for l in range(12) 77 | for h in list(range(12)) + ["mlp"] 78 | ] 79 | 80 | 81 | def iterative_path_patch(matchers_to_extend, matcher_extenders, patch_data): 82 | """Calls iterative path patching, keeping the baseline / patch data, group, input_name, and output_name constant""" 83 | return ( 84 | iterative_path_patching_nocorr( 85 | circuit=ld_circuit, 86 | matchers_to_extend=matchers_to_extend, 87 | baseline_data=ds.good_toks, 88 | patch_data=patch_data, 89 | group=group, 90 | matcher_extenders=matcher_extenders, 91 | input_name="tokens", 92 | output_shape=(12, 13, -1), 93 | ).mean(-1) 94 | ) - baseline_mean 95 | 96 | 97 | #%% 98 | m11_extender = extender_factory(MLPHeadAndPosSpec(11, cast(HeadOrMlpType, "mlp"), END_POS), qkv=None) 99 | m11_matcher = m11_extender(corr_root_matcher) 100 | m10_extender = extender_factory(MLPHeadAndPosSpec(10, cast(HeadOrMlpType, "mlp"), END_POS), qkv=None) 101 | m10_matcher = m10_extender(corr_root_matcher | m11_matcher) 102 | m9_extender = extender_factory(MLPHeadAndPosSpec(9, cast(HeadOrMlpType, "mlp"), END_POS), qkv=None) 103 | m9_matcher = m9_extender(corr_root_matcher | m11_matcher | m10_matcher) 104 | m8_extender = extender_factory(MLPHeadAndPosSpec(8, cast(HeadOrMlpType, "mlp"), END_POS), qkv=None) 105 | m8_matcher = m8_extender(corr_root_matcher | m11_matcher | m10_matcher | m9_matcher) 106 | 107 | 108 | mlp_set_extender = extender_factory( 109 | {MLPHeadAndPosSpec(i, cast(HeadOrMlpType, "mlp"), END_POS) for i in range(10,11)}, qkv=None 110 | ) 111 | 112 | heads2 = [(9, 2), (8, 0), (7, 11), (6, 10), (5, 6), (8, 9), (5, 2)] 113 | heads_orig = [(9, 1), (8, 11), (7, 10), (6, 9), (5, 5), (8, 8), (5, 1)] 114 | heads = heads_orig 115 | 116 | attention_set_extenders = extender_factory( 117 | { 118 | MLPHeadAndPosSpec(layer, head, END_POS) 119 | for layer, head in heads 120 | } 121 | ) 122 | running = corr_root_matcher 123 | ms = attention_set_extenders(corr_root_matcher) 124 | for i in range(4): 125 | running = mlp_set_extender(running) 126 | ms = ms | attention_set_extenders(running) 127 | 128 | 129 | patched_circuit = path_patching( 130 | ld_circuit, 131 | ds.bad_toks, # unpatched nodes get bad data 132 | ds.good_toks, # patched ndoes get good data 133 | ms, 134 | group, 135 | "tokens", 136 | ) 137 | 138 | patched_results = se(patched_circuit).mean() 139 | ms_patched_results = patched_results 140 | print(patched_results, baseline_mean, patched_results / baseline_mean) 141 | 142 | 143 | probs = torch.softmax(sec(patched_circuit.get_unique("logits")), dim=-1)[:, -1, year_indices] 144 | fig = show_diffs( 145 | probs, 146 | center_zero=False, 147 | zrange=(0.0, 0.25), 148 | title="GPT-2 Small Probability Heatmap (Patched)", 149 | zlabel="probability", 150 | color_continuous_scale="amp", 151 | ) 152 | fig.show() 153 | 154 | 155 | # %% 156 | -------------------------------------------------------------------------------- /neuron_plots.py: -------------------------------------------------------------------------------- 1 | #%% 2 | import torch 3 | import plotly.express as px 4 | from plotly.subplots import make_subplots 5 | import plotly.graph_objects as go 6 | 7 | # %% 8 | outer_products = torch.load("paper-cache/t10-logitlens.pt") 9 | direct_effects = torch.load("paper-cache/t10-direct_effects.pt") 10 | patched_logit_diff = torch.load("paper-cache/t10-t10patched.pt") 11 | order = torch.load("cache/order.pt") 12 | #%% 13 | top3_fig = make_subplots( 14 | rows=1, 15 | cols=3, 16 | shared_xaxes=False, 17 | shared_yaxes=True, 18 | subplot_titles=[f"MLP 10 Neuron {i}" for i in order[:3]], 19 | horizontal_spacing=0.02, 20 | vertical_spacing=0.05, 21 | ) 22 | top3_fig.add_trace(go.Heatmap(z=outer_products[0].cpu(), coloraxis="coloraxis1"), row=1, col=1) 23 | top3_fig.add_trace(go.Heatmap(z=outer_products[1].cpu(), coloraxis="coloraxis1"), row=1, col=2) 24 | top3_fig.add_trace(go.Heatmap(z=outer_products[2].cpu(), coloraxis="coloraxis1"), row=1, col=3) 25 | 26 | top3_fig.update_layout( 27 | width=1000, 28 | height=350, 29 | coloraxis1=dict( 30 | colorscale="RdBu", 31 | colorbar_x=1.007, 32 | colorbar_thickness=23, 33 | colorbar_title="Logit Lens Magnitude", 34 | colorbar_title_side="right", 35 | ), 36 | ) 37 | top3_fig.update_yaxes(autorange="reversed") 38 | top3_fig.update_layout( 39 | xaxis=dict(title="Predicted Year"), 40 | xaxis2=dict(title="Predicted Year"), 41 | xaxis3=dict(title="Predicted Year"), 42 | yaxis=dict(title="YY", tickmode="array", tickvals=list(range(0, 98, 5)), ticktext=list(range(2, 99, 5))), 43 | ) 44 | top3_fig.update_layout(margin=dict(l=30, r=30, b=30, t=30), title_x=0.5) 45 | top3_fig.show() 46 | top3_fig.write_image("paper-plots/top3-neurons.pdf") 47 | # %% 48 | top3_fig = make_subplots( 49 | rows=1, 50 | cols=3, 51 | shared_xaxes=False, 52 | shared_yaxes=True, 53 | subplot_titles=[f"MLP 10 Neuron {i}" for i in order[:3]], 54 | horizontal_spacing=0.02, 55 | vertical_spacing=0.05, 56 | ) 57 | top3_fig.add_trace(go.Heatmap(z=direct_effects[0].cpu(), coloraxis="coloraxis1"), row=1, col=1) 58 | top3_fig.add_trace(go.Heatmap(z=direct_effects[1].cpu(), coloraxis="coloraxis1"), row=1, col=2) 59 | top3_fig.add_trace(go.Heatmap(z=direct_effects[2].cpu(), coloraxis="coloraxis1"), row=1, col=3) 60 | 61 | top3_fig.update_layout( 62 | width=1000, 63 | height=350, 64 | coloraxis1=dict( 65 | colorscale="RdBu", 66 | colorbar_x=1.007, 67 | colorbar_thickness=23, 68 | colorbar_title="Logit Change", 69 | colorbar_title_side="right", 70 | ), 71 | ) 72 | top3_fig.update_yaxes(autorange="reversed") 73 | top3_fig.update_layout( 74 | xaxis=dict(title="Predicted Year"), 75 | xaxis2=dict(title="Predicted Year"), 76 | xaxis3=dict(title="Predicted Year"), 77 | yaxis=dict(title="YY", tickmode="array", tickvals=list(range(0, 98, 5)), ticktext=list(range(2, 99, 5))), 78 | ) 79 | top3_fig.update_layout(margin=dict(l=30, r=30, b=30, t=30), title_x=0.5) 80 | top3_fig.show() 81 | top3_fig.write_image("paper-plots/top3-neurons-directeffects.pdf") 82 | # %% 83 | ll_sum_fig = make_subplots( 84 | rows=1, 85 | cols=1, 86 | shared_xaxes=False, 87 | shared_yaxes=True, 88 | subplot_titles=[f"Logit Lens of Top-10 MLP 10 Neurons"], 89 | horizontal_spacing=0.02, 90 | vertical_spacing=0.05, 91 | ) 92 | ll_10 = outer_products[:10].sum(0).cpu() 93 | ll_sum_fig.add_trace(go.Heatmap(z=ll_10 - ll_10[0:1], coloraxis="coloraxis1"), row=1, col=1) 94 | ll_sum_fig.update_layout( 95 | width=450, 96 | height=350, 97 | coloraxis1=dict( 98 | colorscale="RdBu", 99 | colorbar_x=1.007, 100 | colorbar_thickness=23, 101 | colorbar_title="Logit Change", 102 | colorbar_title_side="right", 103 | cmin=-30, 104 | cmax=30, 105 | ), 106 | ) 107 | ll_sum_fig.update_yaxes(autorange="reversed") 108 | ll_sum_fig.update_layout( 109 | xaxis=dict(title="Predicted Year"), 110 | yaxis=dict(title="YY", tickmode="array", tickvals=list(range(0, 98, 5)), ticktext=list(range(2, 99, 5))), 111 | ) 112 | ll_sum_fig.update_layout(margin=dict(l=30, r=30, b=30, t=30), title_x=0.5) 113 | ll_sum_fig.show() 114 | ll_sum_fig.write_image("paper-plots/top10-neurons-logitlenssum.pdf") 115 | # %% 116 | de_fig = make_subplots( 117 | rows=1, 118 | cols=2, 119 | shared_xaxes=False, 120 | shared_yaxes=True, 121 | subplot_titles=[f"{t} Direct Effects" for t in ["Summed", "Patched"]], 122 | horizontal_spacing=0.02, 123 | vertical_spacing=0.05, 124 | ) 125 | de_fig.add_trace(go.Heatmap(z=direct_effects[:10].sum(0).cpu(), coloraxis="coloraxis1"), row=1, col=1) 126 | de_fig.add_trace(go.Heatmap(z=patched_logit_diff.cpu(), coloraxis="coloraxis1"), row=1, col=2) 127 | 128 | de_fig.update_layout( 129 | width=750, 130 | height=350, 131 | coloraxis1=dict( 132 | colorscale="RdBu", 133 | colorbar_x=1.007, 134 | colorbar_thickness=23, 135 | colorbar_title="Logit Change", 136 | colorbar_title_side="right", 137 | cmin=-3, 138 | cmax=3, 139 | ), 140 | ) 141 | de_fig.update_yaxes(autorange="reversed") 142 | de_fig.update_layout( 143 | xaxis=dict(title="Predicted Year"), 144 | xaxis2=dict(title="Predicted Year"), 145 | yaxis=dict(title="YY", tickmode="array", tickvals=list(range(0, 98, 5)), ticktext=list(range(2, 99, 5))), 146 | ) 147 | de_fig.update_layout(margin=dict(l=30, r=30, b=30, t=30), title_x=0.5) 148 | de_fig.show() 149 | de_fig.write_image("paper-plots/top10-neurons-directeffects.pdf") 150 | #%% 151 | top100_fig = make_subplots( 152 | rows=1, 153 | cols=2, 154 | shared_xaxes=False, 155 | shared_yaxes=True, 156 | subplot_titles=[f"Top-{i} MLP 10 Neurons Logit Lens" for i in [100, 200]], 157 | horizontal_spacing=0.02, 158 | vertical_spacing=0.05, 159 | ) 160 | ll_100 = outer_products[:100].sum(0).cpu() 161 | ll_200 = outer_products[:200].sum(0).cpu() 162 | top100_fig.add_trace(go.Heatmap(z=ll_100 - ll_100[0:1], coloraxis="coloraxis1"), row=1, col=1) 163 | top100_fig.add_trace(go.Heatmap(z=ll_200 - ll_200[0:1], coloraxis="coloraxis1"), row=1, col=2) 164 | 165 | top100_fig.update_layout( 166 | width=750, 167 | height=350, 168 | coloraxis1=dict( 169 | colorscale="RdBu", 170 | colorbar_x=1.007, 171 | colorbar_thickness=23, 172 | colorbar_title="Logit Lens Magnitude", 173 | colorbar_title_side="right", 174 | cmin=-45, 175 | cmax=50, 176 | ), 177 | ) 178 | top100_fig.update_yaxes(autorange="reversed") 179 | top100_fig.update_layout( 180 | xaxis=dict(title="Predicted Year"), 181 | xaxis2=dict(title="Predicted Year"), 182 | xaxis3=dict(title="Predicted Year"), 183 | yaxis=dict(title="YY", tickmode="array", tickvals=list(range(0, 98, 5)), ticktext=list(range(2, 99, 5))), 184 | ) 185 | top100_fig.update_layout(margin=dict(l=30, r=30, b=30, t=30), title_x=0.5) 186 | top100_fig.show() 187 | top100_fig.write_image("paper-plots/top100-neurons.pdf") 188 | # %% 189 | y_labels = list(range(2, 99)) 190 | title = None # "MLP Iterative Path Patching" 191 | full_fig = px.imshow( 192 | outer_products[3:11].cpu(), 193 | aspect="equal", 194 | facet_col_spacing=0, 195 | facet_row_spacing=0, 196 | facet_col=0, 197 | facet_col_wrap=4, 198 | labels=dict(x="Predicted Year", y="YY"), 199 | title=title, 200 | # x=x_labels, 201 | y=y_labels, 202 | range_color=(-5.0, 5.0), 203 | zmin=-1.0, 204 | zmax=1.0, 205 | color_continuous_scale="RdBu", 206 | ) 207 | full_fig.update_layout( 208 | margin=dict(l=0, r=0, b=30, t=20), 209 | coloraxis_colorbar_x=1.0, 210 | coloraxis_colorbar_title=f"Logit Lens Magnitude", 211 | coloraxis_colorbar_title_side="right", 212 | ) 213 | 214 | 215 | full_fig.update_layout(title_x=0.5) 216 | for i, label in enumerate([f"MLP 10 Neuron {i}" for i in (order[7:11].tolist() + order[3:7].tolist())]): 217 | full_fig.layout.annotations[i]["text"] = label 218 | full_fig.show() 219 | full_fig.write_image("paper-plots/appendix/neurons.pdf") 220 | -------------------------------------------------------------------------------- /neuron_investigations.py: -------------------------------------------------------------------------------- 1 | #%% 2 | from pathlib import Path 3 | 4 | import torch 5 | import plotly.express as px 6 | import rust_circuit as rc 7 | 8 | from rust_circuit.causal_scrubbing.hypothesis import corr_root_matcher 9 | 10 | from dataset import YearDataset 11 | from utils import ( 12 | load_gpt2_small_circuit, 13 | load_diff_model, 14 | MLPHeadAndPosSpec, 15 | path_patching, 16 | load_and_split_gpt2, 17 | make_extender_factory, 18 | show_diffs, 19 | split_mlp, 20 | replace_inputs, 21 | get_valid_years, 22 | ) 23 | 24 | # %% 25 | DEVICE = "cuda:0" 26 | _, tokenizer, _ = load_gpt2_small_circuit() 27 | 28 | #%% 29 | years_to_sample_from = get_valid_years(tokenizer, 1000, 1900) 30 | N = 490 31 | ds = YearDataset(years_to_sample_from, N, Path("cache/potential_nouns.txt"), ordered=True, device=DEVICE, eos=True) 32 | 33 | MAX_LEN = ds.good_toks.size(-1) 34 | extender_factory = make_extender_factory(MAX_LEN) 35 | END_POS = MAX_LEN - 1 36 | XX1_POS = ds.good_prompt.index("XX1") 37 | YY_POS = ds.good_prompt.index("YY") 38 | last_two_digits = ds.years_YY 39 | 40 | #%% 41 | logit_diff = True 42 | metric = "logit mean" if logit_diff else "prob" 43 | circuit = load_and_split_gpt2(MAX_LEN) 44 | year_indices = torch.load("cache/logit_indices.pt") 45 | ld_circuit, group = load_diff_model(circuit, year_indices, ds.good_mask, logit_diff=logit_diff, device=DEVICE) 46 | # %% 47 | c = replace_inputs(ld_circuit, ds.good_toks, "tokens", corr_root_matcher, group) 48 | d = replace_inputs(ld_circuit, ds.bad_toks, "tokens", corr_root_matcher, group) 49 | 50 | def se(c): 51 | """Short function for Sample and Evaluate along the global variable `group`""" 52 | transform = rc.Sampler(rc.RunDiscreteVarAllSpec([group])) 53 | return transform.sample(c).evaluate() 54 | 55 | 56 | #%% 57 | extender_factory = make_extender_factory(MAX_LEN) 58 | layer = "m10" 59 | direct_layer_extender = extender_factory(MLPHeadAndPosSpec(int(layer[1:]), "mlp", END_POS)) 60 | baseline_logits = se(c.get_unique("year_logits")) 61 | weights = c.get_unique(f"{layer}.w.proj_out").value 62 | unembed_matrix = c.get_unique("t.w.tok_embeds").value 63 | activations = se(c.get_unique(rc.IterativeMatcher(layer).chain(rc.restrict("m.act", term_if_matches=True))))[:, -1] 64 | order = torch.load('cache/order.pt') 65 | direct_effects = [] 66 | outer_products = [] 67 | 68 | # We save the top 11 instead of top 10, because the plot looks nicer with one extra neuron 69 | for i, neuron in enumerate(order[:11]): 70 | # getting effects the outer product way 71 | acts = activations[:, neuron] 72 | logit_lens_weights = torch.einsum("h,lh->l", weights[:, neuron], unembed_matrix[year_indices]) 73 | outer_product = torch.einsum("y,l -> yl", acts, logit_lens_weights) 74 | outer_products.append(outer_product.cpu()) 75 | 76 | # getting results the direct effects way 77 | split_circuit = c.update(layer, lambda node: split_mlp(node, torch.tensor([neuron]))) 78 | patched_circuit = path_patching( 79 | split_circuit, 80 | ds.good_toks, 81 | ds.bad_toks, 82 | direct_layer_extender(corr_root_matcher).chain("m.pre_important"), 83 | group, 84 | "tokens", 85 | ) 86 | patched_logits = se(patched_circuit.get_unique("year_logits")) 87 | logit_diff = baseline_logits - patched_logits - baseline_logits[:, 0:1] + patched_logits[:, 0:1] 88 | direct_effects.append(logit_diff.cpu()) 89 | 90 | outer_products = torch.stack(outer_products) 91 | direct_effects = torch.stack(direct_effects) 92 | torch.save(outer_products,'paper-cache/t10-logitlens.pt') 93 | torch.save(direct_effects,'paper-cache/t10-direct_effects.pt') 94 | #%% 95 | top = 10 96 | extender_factory = make_extender_factory(MAX_LEN) 97 | layer = "m10" 98 | direct_layer_extender = extender_factory(MLPHeadAndPosSpec(int(layer[1:]), "mlp", END_POS)) 99 | baseline_logits = se(c.get_unique("year_logits")) 100 | split_circuit = c.update(layer, lambda node: split_mlp(node, torch.tensor(order[:top]))) 101 | patched_circuit = path_patching( 102 | split_circuit, 103 | ds.good_toks, 104 | ds.bad_toks, 105 | direct_layer_extender(corr_root_matcher).chain("m.pre_important"), 106 | group, 107 | "tokens", 108 | ) 109 | patched_logits = se(patched_circuit.get_unique("year_logits")) 110 | patched_logit_diff = baseline_logits - patched_logits - baseline_logits[:, 0:1] + patched_logits[:, 0:1] 111 | show_diffs( 112 | patched_logit_diff.cpu(), title=f"Top {top} patched direct contributions", center_zero=True 113 | ) 114 | 115 | #%% 116 | torch.save(patched_logit_diff, 'paper-cache/t10-t10patched.pt') 117 | 118 | #%% 119 | from plotly.subplots import make_subplots 120 | import plotly.graph_objects as go 121 | 122 | top2_fig = make_subplots( 123 | rows=1, 124 | cols=2, 125 | shared_xaxes=False, 126 | shared_yaxes=True, 127 | subplot_titles=[f"MLP 10 Neuron {i} Contributions" for i in order[:2]], 128 | horizontal_spacing=0.02, 129 | vertical_spacing=0.05, 130 | ) 131 | top2_fig.add_trace(go.Heatmap(z=direct_effects[0].cpu(), coloraxis="coloraxis1"), row=1, col=1) 132 | top2_fig.add_trace(go.Heatmap(z=direct_effects[0].cpu(), coloraxis="coloraxis1"), row=1, col=2) 133 | 134 | top2_fig.update_layout( 135 | width=1000, 136 | height=400, 137 | coloraxis1=dict( 138 | colorscale="RdBu", 139 | colorbar_x=1.007, 140 | colorbar_thickness=23, 141 | colorbar_title="Logit Change", 142 | colorbar_title_side="right", 143 | ), 144 | ) 145 | top2_fig.update_yaxes(autorange="reversed") 146 | top2_fig.update_layout( 147 | xaxis=dict(title="Predicted Year"), 148 | xaxis2=dict(title="Predicted Year"), 149 | yaxis=dict(title="YY", tickmode="array", tickvals=list(range(0, 98, 5)), ticktext=list(range(2, 99, 5))), 150 | ) 151 | top2_fig.update_layout(margin=dict(l=30, r=30, b=30, t=30), title_x=0.5) 152 | top2_fig.show() 153 | top2_fig.write_image("paper-plots/top2-neurons.pdf") 154 | 155 | #%% 156 | from plotly.subplots import make_subplots 157 | import plotly.graph_objects as go 158 | 159 | summed_logit_diff = outer_products[:10].sum(0) 160 | 161 | top10_fig = make_subplots( 162 | rows=1, 163 | cols=2, 164 | shared_xaxes=False, 165 | shared_yaxes=True, 166 | subplot_titles=[f"MLP 10 Neuron {i} Contributions" for i in order[:2]], 167 | horizontal_spacing=0.02, 168 | vertical_spacing=0.05, 169 | ) 170 | top10_fig.add_trace(go.Heatmap(z=patched_logit_diff.cpu(), coloraxis="coloraxis1"), row=1, col=1) 171 | top10_fig.add_trace(go.Heatmap(z=summed_logit_diff.cpu(), coloraxis="coloraxis1"), row=1, col=2) 172 | 173 | top10_fig.update_layout( 174 | width=1000, 175 | height=400, 176 | coloraxis1=dict( 177 | colorscale="RdBu", 178 | colorbar_x=1.007, 179 | colorbar_thickness=23, 180 | colorbar_title="Logit Change", 181 | colorbar_title_side="right", 182 | cmin=-3, 183 | cmax=3, 184 | ), 185 | ) 186 | top10_fig.update_yaxes(autorange="reversed") 187 | top10_fig.update_layout( 188 | xaxis=dict(title="Predicted Year"), 189 | xaxis2=dict(title="Predicted Year"), 190 | yaxis=dict(title="YY", tickmode="array", tickvals=list(range(0, 98, 5)), ticktext=list(range(2, 99, 5))), 191 | ) 192 | top10_fig.update_layout(margin=dict(l=30, r=30, b=30, t=30), title_x=0.5) 193 | top10_fig.show() 194 | top10_fig.write_image("paper-plots/top10-patched-summed.pdf") 195 | #%% 196 | y_labels = list(range(2,99)) 197 | title = None # "MLP Iterative Path Patching" 198 | full_fig = px.imshow( 199 | outer_products[3:11].cpu(), 200 | aspect="equal", 201 | facet_col_spacing=0, 202 | facet_row_spacing=0, 203 | facet_col=0, 204 | facet_col_wrap=4, 205 | labels=dict(x="Predicted Year", y="YY"), 206 | title=title, 207 | # x=x_labels, 208 | y=y_labels, 209 | range_color=(-5.0, 5.0), 210 | zmin=-1.0, 211 | zmax=1.0, 212 | color_continuous_scale="RdBu", 213 | ) 214 | full_fig.update_layout( 215 | margin=dict(l=0, r=0, b=30, t=20), 216 | coloraxis_colorbar_x=1.0, 217 | ) 218 | 219 | full_fig.update_layout(title_x=0.5) 220 | for i, label in enumerate([f"MLP 10 Neuron {i}" for i in (order[7:11].tolist() + order[3:7].tolist())]): 221 | full_fig.layout.annotations[i]["text"] = label 222 | full_fig.show() 223 | full_fig.write_image("paper-plots/appendix/neurons.pdf") 224 | -------------------------------------------------------------------------------- /big_ds_experiments.py: -------------------------------------------------------------------------------- 1 | #%% 2 | from typing import cast, List 3 | from pathlib import Path 4 | 5 | import torch 6 | import rust_circuit as rc 7 | import matplotlib.pyplot as plt 8 | from tqdm import tqdm 9 | 10 | from rust_circuit.causal_scrubbing.hypothesis import corr_root_matcher 11 | 12 | from dataset import YearDataset 13 | from utils import ( 14 | collate, 15 | HeadOrMlpType, 16 | MLPHeadAndPosSpec, 17 | load_diff_model, 18 | load_gpt2_small_circuit, 19 | path_patching, 20 | to_device, 21 | load_and_split_gpt2, 22 | make_extender_factory, 23 | show_diffs, 24 | get_valid_years, 25 | prob_diff, 26 | cutoff_sharpness, 27 | make_all_nodes_names, 28 | ) 29 | 30 | #%% 31 | # Loading our base model 32 | DEVICE = "cuda:0" 33 | _, tokenizer, _ = load_gpt2_small_circuit() 34 | 35 | #%% 36 | # Creating our dataset 37 | years_to_sample_from = get_valid_years(tokenizer, 1000, 1900) 38 | N = 10000 39 | ds = YearDataset( 40 | years_to_sample_from, 41 | N, 42 | Path("cache/potential_nouns.txt"), 43 | tokenizer, 44 | balanced=False, 45 | eos=True, 46 | device=DEVICE, 47 | ) 48 | 49 | MAX_LEN = ds.good_toks.size(-1) 50 | END_POS = MAX_LEN - 1 51 | XX1_POS = ds.good_prompt.index("XX1") 52 | YY_POS = ds.good_prompt.index("YY") 53 | last_two_digits = ds.years_YY 54 | 55 | #%% 56 | # Splitting our model to make it pretty 57 | metric = "prob" 58 | circuit = load_and_split_gpt2(MAX_LEN) 59 | year_indices = torch.load("cache/logit_indices.pt") 60 | 61 | 62 | def se(c): 63 | """Short function for Sample and Evaluate along the global variable `group`""" 64 | transform = rc.Sampler(rc.RunDiscreteVarAllSpec([group])) 65 | c = c.update("year_logits", lambda x: rc.batch_to_concat(x, axis=0, batch_size=100)) 66 | return transform.sample(c).evaluate().cpu() 67 | 68 | #%% 69 | # manual batching 70 | batch_size = 200 71 | logits = [] 72 | for i in tqdm(range(N // batch_size)): 73 | # Let's make a copy of the circuit that actually has inputs! 74 | ld_circuit, group = load_diff_model( 75 | circuit, year_indices, ds.good_mask[i * batch_size : (i + 1) * batch_size], device=DEVICE 76 | ) 77 | c = ld_circuit.update( 78 | "tokens", 79 | lambda _: rc.DiscreteVar( 80 | to_device(rc.Array(ds.good_toks[i * batch_size : (i + 1) * batch_size], name="tokens"), DEVICE), 81 | probs_and_group=group, 82 | ), 83 | ) 84 | logits.append(se(c.get_unique("logits"))[:, -1].cpu()) 85 | logits = torch.stack(logits).view(N, -1) 86 | #%% 87 | probs = torch.softmax(logits, dim=-1)[:, year_indices] 88 | torch.save(collate(probs, ds.years_YY), "paper-cache/probs.pt") 89 | diffs = prob_diff(probs, ds.years_YY) 90 | print(diffs.mean(), diffs.std()) 91 | 92 | sharpness = cutoff_sharpness(probs, ds.years_YY) 93 | print(sharpness.mean(), sharpness.std()) 94 | #%% 95 | yearwise_probs = collate(probs, ds.years_YY) 96 | 97 | fig = show_diffs( 98 | yearwise_probs, 99 | center_zero=False, 100 | zrange=(0.0, 0.25), 101 | title="GPT-2 Small Probability Heatmap", 102 | zlabel="probability", 103 | color_continuous_scale="amp", 104 | ) 105 | fig.show() 106 | #%% 107 | # Defining the matchers 108 | 109 | 110 | extender_factory = make_extender_factory(MAX_LEN) 111 | end_pos_matcher_extenders = [ 112 | extender_factory(MLPHeadAndPosSpec(l, cast(HeadOrMlpType, h), END_POS), qkv=None) 113 | for l in range(12) 114 | for h in list(range(12)) + ["mlp"] 115 | ] 116 | 117 | mlp_set_extender = extender_factory( 118 | {MLPHeadAndPosSpec(i, cast(HeadOrMlpType, "mlp"), END_POS) for i in range(8, 12)}, qkv=None 119 | ) 120 | attention_set_extenders = extender_factory( 121 | { 122 | MLPHeadAndPosSpec(layer, head, END_POS) 123 | for layer, head in [(9, 1), (8, 11), (7, 10), (6, 9), (5, 5), (8, 8), (5, 1)] 124 | } 125 | ) 126 | running = corr_root_matcher 127 | ms = attention_set_extenders(corr_root_matcher) 128 | for i in range(4): 129 | running = mlp_set_extender(running) 130 | ms = ms | attention_set_extenders(running) 131 | 132 | #%% 133 | batch_size = 200 134 | patched_probs = [] 135 | for i in tqdm(range(N // batch_size)): 136 | # Let's make a copy of the circuit that actually has inputs! 137 | ld_circuit, group = load_diff_model( 138 | circuit, year_indices, ds.good_mask[i * batch_size : (i + 1) * batch_size], device=DEVICE 139 | ) 140 | patched_circuit = path_patching( 141 | ld_circuit, 142 | ds.bad_toks[i * batch_size : (i + 1) * batch_size], # unpatched nodes get bad data 143 | ds.good_toks[i * batch_size : (i + 1) * batch_size], # patched nodes get good data 144 | ms, 145 | group, 146 | "tokens", 147 | ) 148 | patched_probs.append(torch.softmax(se(patched_circuit.get_unique("logits"))[:, -1], dim=-1)[:, year_indices].cpu()) 149 | patched_probs = torch.stack(patched_probs).view(N, -1) 150 | 151 | 152 | torch.save(collate(patched_probs, ds.years_YY), "paper-cache/patched_probs.pt") 153 | #%% 154 | patched_diffs = prob_diff(patched_probs, ds.years_YY) 155 | mean_diffs = 0.817 156 | print(patched_diffs.mean()) 157 | print(patched_diffs.mean() / mean_diffs) 158 | 159 | patched_sharpness = cutoff_sharpness(patched_probs, ds.years_YY) 160 | mean_sharpness = 0.059 161 | print(patched_sharpness.mean()) 162 | print(patched_sharpness.mean() / mean_sharpness) 163 | 164 | #%% 165 | yearwise_patched_probs = [] 166 | for year in range(2, 99): 167 | yearwise_patched_probs.append(patched_probs[ds.years_YY == year].mean(0)) 168 | yearwise_patched_probs = torch.stack(yearwise_patched_probs) 169 | 170 | fig = show_diffs( 171 | yearwise_patched_probs, 172 | center_zero=False, 173 | zrange=(0.0, 0.25), 174 | title="Patched GPT-2 Small Probability Heatmap", 175 | zlabel="probability", 176 | color_continuous_scale="Blues", 177 | ) 178 | fig.update_layout(margin=dict(l=0, r=0, b=10, t=30), title_x=0.5, height=500, width=600) 179 | fig.show() 180 | fig.write_image("paper-plots/patched-probability-heatmap.pdf") 181 | # %% 182 | 183 | reverse_patched_circuit = path_patching( 184 | ld_circuit, 185 | ds.good_toks, # unpatched nodes get bad data 186 | ds.bad_toks, # patched ndoes get good data 187 | ms, 188 | group, 189 | "tokens", 190 | ) 191 | 192 | reverse_patched_results = se(reverse_patched_circuit) 193 | reverse_patched_mean = reverse_patched_results.mean() 194 | print(reverse_patched_mean) 195 | # %% 196 | reverse_patched_probs = torch.softmax(se(reverse_patched_circuit.get_unique("logits")), dim=-1)[:, -1, year_indices] 197 | reverse_patched_diffs = prob_diff(reverse_patched_probs, ds.years_YY) 198 | print(reverse_patched_diffs.mean()) 199 | print(reverse_patched_diffs.mean() / diffs.mean()) 200 | 201 | # %% 202 | def embed_extender(m: rc.IterativeMatcher) -> rc.IterativeMatcher: 203 | return m.chain(rc.new_traversal(start_depth=0, end_depth=1)).chain( 204 | rc.new_traversal(start_depth=1, end_depth=2).chain( 205 | rc.restrict( 206 | rc.Matcher("embeds"), 207 | term_early_at=rc.Matcher(make_all_nodes_names(MAX_LEN)), 208 | term_if_matches=True, 209 | ) 210 | ) 211 | ) 212 | 213 | 214 | lower_extenders = extender_factory( 215 | { 216 | MLPHeadAndPosSpec(layer, head, YY_POS) 217 | for layer, head in [(3, "mlp"), (2, "mlp"), (1, "mlp"), (0, "mlp"), (0, 5), (0, 3), (0, 1)] 218 | } 219 | ) 220 | lower_extenders2 = extender_factory( 221 | { 222 | MLPHeadAndPosSpec(layer, head, YY_POS) 223 | for layer, head in [(3, "mlp"), (2, "mlp"), (1, "mlp"), (0, "mlp"), (0, 1)] 224 | } 225 | ) 226 | running = lower_extenders(ms) 227 | lms = embed_extender(running) # running.chain('embeds') 228 | for i in range(4): 229 | running = lower_extenders2(running) 230 | lms = lms | embed_extender(running) # running.chain('embeds') 231 | lms = lms | ms.chain(rc.restrict({"a.q"}, term_if_matches=True, end_depth=8)) 232 | 233 | #%% 234 | batch_size = 400 235 | patched_probs = [] 236 | for i in tqdm(range(N // batch_size)): 237 | # Let's make a copy of the circuit that actually has inputs! 238 | ld_circuit, group = load_diff_model( 239 | circuit, year_indices, ds.good_mask[i * batch_size : (i + 1) * batch_size], device=DEVICE 240 | ) 241 | patched_circuit = path_patching( 242 | ld_circuit, 243 | ds.bad_toks[i * batch_size : (i + 1) * batch_size], # unpatched nodes get bad data 244 | ds.good_toks[i * batch_size : (i + 1) * batch_size], # patched nodes get good data 245 | lms, 246 | group, 247 | "tokens", 248 | ) 249 | patched_probs.append(torch.softmax(se(patched_circuit.get_unique("logits"))[:, -1], dim=-1)[:, year_indices].cpu()) 250 | patched_probs = torch.stack(patched_probs).view(N, -1) 251 | 252 | 253 | torch.save(collate(patched_probs, ds.years_YY), "paper-cache/full_patched_probs.pt") -------------------------------------------------------------------------------- /appendix_plots.py: -------------------------------------------------------------------------------- 1 | #%% 2 | import torch 3 | from plotly.subplots import make_subplots 4 | import plotly.graph_objects as go 5 | 6 | #%% 7 | failure_probs = [torch.load(f"paper-cache/generalization/probs_{option}.pt") for option in range(3, 6)] 8 | 9 | failure_prob_fig = make_subplots( 10 | rows=1, 11 | cols=3, 12 | shared_xaxes=False, 13 | shared_yaxes=True, 14 | subplot_titles=["Probability Heatmap"] * 3, 15 | horizontal_spacing=0.02, 16 | vertical_spacing=0.05, 17 | ) 18 | failure_prob_fig.add_trace(go.Heatmap(z=failure_probs[0].cpu(), coloraxis="coloraxis1"), row=1, col=1) 19 | failure_prob_fig.add_trace(go.Heatmap(z=failure_probs[1].cpu(), coloraxis="coloraxis1"), row=1, col=2) 20 | failure_prob_fig.add_trace(go.Heatmap(z=failure_probs[2].cpu(), coloraxis="coloraxis1"), row=1, col=3) 21 | 22 | failure_prob_fig.update_layout( 23 | width=1400, 24 | height=400, 25 | coloraxis1=dict( 26 | colorscale="Blues", 27 | colorbar_x=1.007, 28 | colorbar_thickness=23, 29 | colorbar_title="Probability", 30 | colorbar_title_side="right", 31 | cmin=0.0, 32 | cmax=0.25, 33 | ), 34 | ) 35 | failure_prob_fig.update_yaxes(autorange="reversed") 36 | failure_prob_fig.update_layout( 37 | xaxis=dict(title="Predicted Year"), 38 | xaxis2=dict(title="Predicted Year"), 39 | xaxis3=dict(title="Predicted Year"), 40 | yaxis=dict(title="YY", tickmode="array", tickvals=list(range(0, 98, 5)), ticktext=list(range(2, 99, 5))), 41 | ) 42 | failure_prob_fig.update_layout(margin=dict(l=30, r=30, b=30, t=30), title_x=0.5) 43 | failure_prob_fig.show() 44 | failure_prob_fig.write_image("paper-plots/appendix/failure-probs.pdf") 45 | # %% 46 | success_probs = [torch.load(f"paper-cache/generalization/probs_{option}.pt") for option in [0,8,2]] 47 | 48 | success_prob_fig = make_subplots( 49 | rows=1, 50 | cols=3, 51 | shared_xaxes=False, 52 | shared_yaxes=True, 53 | subplot_titles=["Probability Heatmap"] * 3, 54 | horizontal_spacing=0.02, 55 | vertical_spacing=0.05, 56 | ) 57 | success_prob_fig.add_trace(go.Heatmap(z=success_probs[0].cpu(), coloraxis="coloraxis1"), row=1, col=1) 58 | success_prob_fig.add_trace(go.Heatmap(z=success_probs[1].cpu(), coloraxis="coloraxis1"), row=1, col=2) 59 | success_prob_fig.add_trace(go.Heatmap(z=success_probs[2].cpu(), coloraxis="coloraxis1"), row=1, col=3) 60 | 61 | success_prob_fig.update_layout( 62 | width=1400, 63 | height=400, 64 | coloraxis1=dict( 65 | colorscale="Blues", 66 | colorbar_x=1.007, 67 | colorbar_thickness=23, 68 | colorbar_title="Probability", 69 | colorbar_title_side="right", 70 | cmin=0.0, 71 | cmax=0.2, 72 | ), 73 | ) 74 | success_prob_fig.update_yaxes(autorange="reversed") 75 | success_prob_fig.update_layout( 76 | xaxis=dict(title="Predicted Year"), 77 | xaxis2=dict(title="Predicted Year"), 78 | xaxis3=dict(title="Predicted Year"), 79 | yaxis=dict(title="YY", tickmode="array", tickvals=list(range(0, 98, 5)), ticktext=list(range(2, 99, 5))), 80 | ) 81 | success_prob_fig.update_layout(margin=dict(l=30, r=30, b=30, t=30), title_x=0.5) 82 | success_prob_fig.show() 83 | success_prob_fig.write_image("paper-plots/appendix/success-probs.pdf") 84 | # %% 85 | bc_probs = [torch.load(f"paper-cache/generalization/probs_{option}.pt") for option in range(6, 8)] 86 | 87 | bc_prob_fig = make_subplots( 88 | rows=1, 89 | cols=2, 90 | shared_xaxes=False, 91 | shared_yaxes=True, 92 | subplot_titles=["Probability Heatmap"] * 2, 93 | horizontal_spacing=0.02, 94 | vertical_spacing=0.05, 95 | ) 96 | bc_prob_fig.add_trace(go.Heatmap(z=bc_probs[0].cpu(), coloraxis="coloraxis1"), row=1, col=1) 97 | bc_prob_fig.add_trace(go.Heatmap(z=bc_probs[1].cpu(), coloraxis="coloraxis1"), row=1, col=2) 98 | 99 | bc_prob_fig.update_layout( 100 | width=1000, 101 | height=400, 102 | coloraxis1=dict( 103 | colorscale="Blues", 104 | colorbar_x=1.007, 105 | colorbar_thickness=23, 106 | colorbar_title="Probability", 107 | colorbar_title_side="right", 108 | cmin=0.0, 109 | cmax=0.2, 110 | ), 111 | ) 112 | bc_prob_fig.update_yaxes(autorange="reversed") 113 | bc_prob_fig.update_layout( 114 | xaxis=dict(title="Predicted Year"), 115 | xaxis2=dict(title="Predicted Year"), 116 | yaxis=dict(title="YY", tickmode="array", tickvals=list(range(0, 98, 5)), ticktext=list(range(2, 99, 5))), 117 | ) 118 | bc_prob_fig.update_layout(margin=dict(l=30, r=30, b=30, t=30), title_x=0.5) 119 | bc_prob_fig.show() 120 | bc_prob_fig.write_image("paper-plots/appendix/bc-probs.pdf") 121 | # %% 122 | success_ipps = [torch.load(f"paper-cache/generalization/ipp_{option}.pt") for option in [0,8,2]] 123 | 124 | success_ipp_fig = make_subplots( 125 | rows=1, 126 | cols=3, 127 | shared_xaxes=False, 128 | shared_yaxes=True, 129 | subplot_titles=["Iterative Path Patching: Logits"] * 3, 130 | horizontal_spacing=0.02, 131 | vertical_spacing=0.05, 132 | ) 133 | success_ipp_fig.add_trace(go.Heatmap(z=success_ipps[0].cpu(), coloraxis="coloraxis1"), row=1, col=1) 134 | success_ipp_fig.add_trace(go.Heatmap(z=success_ipps[1].cpu(), coloraxis="coloraxis1"), row=1, col=2) 135 | success_ipp_fig.add_trace(go.Heatmap(z=success_ipps[2].cpu(), coloraxis="coloraxis1"), row=1, col=3) 136 | 137 | success_ipp_fig.update_layout( 138 | width=1400, 139 | height=400, 140 | coloraxis1=dict( 141 | colorscale="RdBu", 142 | colorbar_x=1.007, 143 | colorbar_thickness=23, 144 | colorbar_title=f"prob diff variation", 145 | colorbar_title_side="right", 146 | cmin=-0.5, 147 | cmax=0.5, 148 | ), 149 | ) 150 | 151 | 152 | success_ipp_fig.update_yaxes(autorange="reversed") 153 | success_ipp_fig.update_layout( 154 | xaxis=dict(title="Head", tickmode="array", tickvals=list(range(13)), ticktext=list(range(12)) + ["mlp"]), 155 | xaxis2=dict(title="Head", tickmode="array", tickvals=list(range(13)), ticktext=list(range(12)) + ["mlp"]), 156 | xaxis3=dict(title="Head", tickmode="array", tickvals=list(range(13)), ticktext=list(range(12)) + ["mlp"]), 157 | yaxis=dict(title="Layer"), 158 | ) 159 | success_ipp_fig.update_layout(margin=dict(l=30, r=30, b=30, t=30), title_x=0.5) 160 | success_ipp_fig.show() 161 | success_ipp_fig.write_image("paper-plots/appendix/success-ipps.pdf") 162 | 163 | # %% 164 | bc_ipps = [torch.load(f"paper-cache/generalization/ipp_{option}.pt") for option in range(6, 8)] 165 | 166 | bc_ipp_fig = make_subplots( 167 | rows=1, 168 | cols=2, 169 | shared_xaxes=False, 170 | shared_yaxes=True, 171 | subplot_titles=["Iterative Path Patching: Logits"] * 2, 172 | horizontal_spacing=0.02, 173 | vertical_spacing=0.05, 174 | ) 175 | bc_ipp_fig.add_trace(go.Heatmap(z=bc_ipps[0].cpu(), coloraxis="coloraxis1"), row=1, col=1) 176 | bc_ipp_fig.add_trace(go.Heatmap(z=bc_ipps[1].cpu(), coloraxis="coloraxis1"), row=1, col=2) 177 | 178 | bc_ipp_fig.update_layout( 179 | width=1000, 180 | height=400, 181 | coloraxis1=dict( 182 | colorscale="RdBu", 183 | colorbar_x=1.007, 184 | colorbar_thickness=23, 185 | colorbar_title=f"prob diff variation", 186 | colorbar_title_side="right", 187 | cmin=-0.25, 188 | cmax=0.25, 189 | ), 190 | ) 191 | 192 | 193 | bc_ipp_fig.update_yaxes(autorange="reversed") 194 | bc_ipp_fig.update_layout( 195 | xaxis=dict(title="Head", tickmode="array", tickvals=list(range(13)), ticktext=list(range(12)) + ["mlp"]), 196 | xaxis2=dict(title="Head", tickmode="array", tickvals=list(range(13)), ticktext=list(range(12)) + ["mlp"]), 197 | yaxis=dict(title="Layer"), 198 | ) 199 | bc_ipp_fig.update_layout(margin=dict(l=30, r=30, b=30, t=30), title_x=0.5) 200 | bc_ipp_fig.show() 201 | bc_ipp_fig.write_image("paper-plots/appendix/bc-ipps.pdf") 202 | # %% 203 | m7m8_ipps = [torch.load(f"paper-cache/generalization/m{i}_2.pt") for i in range(7, 9)] 204 | 205 | m7m8_ipp_fig = make_subplots( 206 | rows=1, 207 | cols=2, 208 | shared_xaxes=False, 209 | shared_yaxes=True, 210 | subplot_titles=[f"Iterative Path Patching: MLP {i}" for i in [7, 8]], 211 | horizontal_spacing=0.02, 212 | vertical_spacing=0.05, 213 | ) 214 | m7m8_ipp_fig.add_trace(go.Heatmap(z=m7m8_ipps[0].cpu(), coloraxis="coloraxis1"), row=1, col=1) 215 | m7m8_ipp_fig.add_trace(go.Heatmap(z=m7m8_ipps[1].cpu(), coloraxis="coloraxis1"), row=1, col=2) 216 | 217 | m7m8_ipp_fig.update_layout( 218 | width=1000, 219 | height=400, 220 | coloraxis1=dict( 221 | colorscale="RdBu", 222 | colorbar_x=1.007, 223 | colorbar_thickness=23, 224 | colorbar_title=f"prob diff variation", 225 | colorbar_title_side="right", 226 | cmin=-0.05, 227 | cmax=0.05, 228 | ), 229 | ) 230 | 231 | 232 | m7m8_ipp_fig.update_yaxes(autorange="reversed") 233 | m7m8_ipp_fig.update_layout( 234 | xaxis=dict(title="Head", tickmode="array", tickvals=list(range(13)), ticktext=list(range(12)) + ["mlp"]), 235 | xaxis2=dict(title="Head", tickmode="array", tickvals=list(range(13)), ticktext=list(range(12)) + ["mlp"]), 236 | yaxis=dict(title="Layer"), 237 | ) 238 | m7m8_ipp_fig.update_layout(margin=dict(l=30, r=30, b=30, t=30), title_x=0.5) 239 | m7m8_ipp_fig.show() 240 | m7m8_ipp_fig.write_image("paper-plots/appendix/m7m8-ipps.pdf") 241 | # %% 242 | -------------------------------------------------------------------------------- /circuit_discovery_plotting.py: -------------------------------------------------------------------------------- 1 | #%% 2 | from typing import cast, List 3 | from pathlib import Path 4 | 5 | import torch 6 | import matplotlib.pyplot as plt 7 | 8 | from plotly.subplots import make_subplots 9 | import plotly.express as px 10 | import plotly.graph_objects as go 11 | 12 | from utils import show_diffs 13 | 14 | # Loading our base model 15 | logit_diff = False 16 | metric = "logit mean" if logit_diff else "prob" 17 | 18 | #%% 19 | probs = torch.load("paper-cache/probs.pt") 20 | fig = show_diffs( 21 | probs, 22 | center_zero=False, 23 | zrange=(0.0, 0.25), 24 | title="GPT-2 Small Probability Heatmap", 25 | zlabel="probability", 26 | color_continuous_scale="Blues", 27 | ) 28 | 29 | fig.update_layout(margin=dict(l=0, r=0, b=10, t=30), title_x=0.5, height=500, width=600) 30 | fig.show() 31 | fig.write_image("paper-plots/probability-heatmap.pdf") 32 | 33 | i = 39 34 | plt.plot(probs[i].cpu()) 35 | plt.title(f"GPT-2 Probabilities when YY={i + 2}") 36 | plt.xlabel(f"Predicted Year") 37 | plt.ylabel(f"probability") 38 | plt.savefig("paper-plots/probability-slice.pdf", bbox_inches="tight") 39 | 40 | 41 | #%% 42 | def make_ipp_plot(results, title, cmin=-2.0, cmax=2.0): 43 | fig = make_subplots( 44 | subplot_titles=[title], 45 | ) 46 | 47 | fig.add_trace(go.Heatmap(z=results, coloraxis="coloraxis1"), row=1, col=1) 48 | fig.update_layout( 49 | width=450, 50 | height=350, 51 | coloraxis1=dict( 52 | colorscale="RdBu", 53 | colorbar_x=1.007, 54 | colorbar_thickness=23, 55 | colorbar_title=f"{metric} diff variation", 56 | colorbar_title_side="right", 57 | cmin=cmin, 58 | cmax=cmax, 59 | ), 60 | ) 61 | fig.update_yaxes(autorange="reversed") 62 | fig.update_layout( 63 | xaxis=dict(title="Head", tickmode="array", tickvals=list(range(13)), ticktext=list(range(12)) + ["mlp"]), 64 | yaxis=dict(title="Layer"), 65 | ) 66 | fig.update_layout(margin=dict(l=30, r=50, b=30, t=30), title_x=0.5) 67 | return fig 68 | 69 | 70 | results_logits = torch.load("paper-cache/ipp_logits.pt").cpu() 71 | results_11, results_10, results_9, results_8 = (torch.load(f"paper-cache/ipp_mlp{i}.pt").cpu() for i in [11, 10, 9, 8]) 72 | 73 | fig_logits = make_ipp_plot(results_logits, "IPP: Direct Contributions to the Logits", -0.5, 0.5) 74 | fig_logits.show() 75 | fig_logits.write_image("paper-plots/logits-ipp.pdf") 76 | 77 | 78 | fig_11 = make_ipp_plot(results_11, "m11", cmin=-0.1, cmax=0.1) 79 | fig_11.show() 80 | fig_11.write_image("paper-plots/appendix/m11-ipp.pdf") 81 | 82 | fig_10 = make_ipp_plot(results_10, "m10", cmin=-0.2, cmax=0.2) 83 | fig_10.show() 84 | fig_10.write_image("paper-plots/appendix/m10-ipp.pdf") 85 | 86 | fig_9 = make_ipp_plot(results_9, "m9", cmin=-0.2, cmax=0.2) 87 | fig_9.show() 88 | fig_9.write_image("paper-plots/appendix/m9-ipp.pdf") 89 | 90 | fig_8 = make_ipp_plot(results_8, "m8", cmin=-0.1, cmax=0.1) 91 | fig_8.show() 92 | fig_8.write_image("paper-plots/appendix/m8-ipp.pdf") 93 | 94 | x_labels = [f"h{i}" for i in range(12)] + ["mlp"] 95 | y_labels = list(range(12)) 96 | title = "IPP: Direct Contributions via MLPs" # "MLP Iterative Path Patching" 97 | full_fig = px.imshow( 98 | torch.stack([results_11, results_10, results_9, results_8]).cpu(), 99 | aspect="equal", 100 | facet_col_spacing=0, 101 | facet_row_spacing=0, 102 | facet_col=0, 103 | facet_col_wrap=2, 104 | labels=dict(x="Head", y="Layer"), 105 | title=title, 106 | x=x_labels, 107 | y=y_labels, 108 | range_color=(-0.2, 0.2), 109 | zmin=-1.0, 110 | zmax=1.0, 111 | color_continuous_scale="RdBu", 112 | ) 113 | full_fig.update_layout( 114 | margin=dict(l=0, r=0, b=30, t=35), 115 | coloraxis_colorbar_title=f"{metric} diff variation", 116 | coloraxis_colorbar_title_side="right", 117 | coloraxis_colorbar_thickness=23, 118 | ) 119 | 120 | full_fig.update_layout(title_x=0.5) 121 | for i, label in enumerate([f"MLP {i}" for i in [9, 8, 11, 10]]): 122 | full_fig.layout.annotations[i]["text"] = label 123 | full_fig.show() 124 | full_fig.write_image("paper-plots/mlps-ipp.pdf") 125 | 126 | 127 | #%% 128 | attn_patterns_7_10 = torch.load("paper-cache/attn_patterns_a7.h10.pt") 129 | attn_patterns_8_11 = torch.load("paper-cache/attn_patterns_a8.h1.pt") 130 | attn_fig = make_subplots( 131 | rows=1, 132 | cols=2, 133 | shared_xaxes=False, 134 | shared_yaxes=False, 135 | subplot_titles=["Mean Attention Pattern for a7.h10", "Mean Attention Pattern for a8.h11"], 136 | horizontal_spacing=0.1, 137 | vertical_spacing=0.15, 138 | ) 139 | attn_fig.add_trace(go.Heatmap(z=attn_patterns_7_10, coloraxis="coloraxis1"), row=1, col=1) 140 | attn_fig.add_trace(go.Heatmap(z=attn_patterns_8_11, coloraxis="coloraxis1"), row=1, col=2) 141 | 142 | attn_fig.update_layout( 143 | width=1000, 144 | height=400, 145 | coloraxis1=dict( 146 | colorscale="Blues", 147 | colorbar_x=1.007, 148 | colorbar_thickness=23, 149 | colorbar_title="Attention", 150 | colorbar_title_side="right", 151 | ), 152 | ) 153 | attn_fig.update_yaxes(autorange="reversed") 154 | prompt = ["", "The", "NOUN", "lasted", "from", "the", "year", "XX1", "YY", "to", "the", "year", "XX2"] 155 | attn_fig.update_layout( 156 | xaxis=dict(title="Key", position=0.5, tickmode="array", tickvals=list(range(len(prompt))), ticktext=prompt), 157 | xaxis2=dict(title="Key", position=0.5, tickmode="array", tickvals=list(range(len(prompt))), ticktext=prompt), 158 | yaxis=dict(title="Query", tickmode="array", tickvals=list(range(len(prompt))), ticktext=prompt), 159 | yaxis2=dict(title=None, tickmode="array", tickvals=list(range(len(prompt))), ticktext=prompt), 160 | ) 161 | 162 | attn_fig.update_layout( 163 | margin=dict(l=30, r=30, b=30, t=30), 164 | title_x=0.5, 165 | ) 166 | attn_fig.show() 167 | attn_fig.write_image("paper-plots/attn-patterns.pdf") 168 | 169 | #%% 170 | ### Logit Lens Attention Heads 171 | logits_a7 = torch.load("paper-cache/logit_lens_a7.h10.pt") 172 | logits_a8 = torch.load("paper-cache/logit_lens_a8.h11.pt") 173 | attn_fig = make_subplots( 174 | rows=1, 175 | cols=2, 176 | shared_xaxes=False, 177 | shared_yaxes=True, 178 | subplot_titles=[f"Logit Lens of a7.h10", "Logit Lens of a8.h11"], 179 | horizontal_spacing=0.02, 180 | vertical_spacing=0.05, 181 | ) 182 | attn_fig.add_trace(go.Heatmap(z=logits_a7, coloraxis="coloraxis1"), row=1, col=1) 183 | attn_fig.add_trace(go.Heatmap(z=logits_a8, coloraxis="coloraxis1"), row=1, col=2) 184 | 185 | attn_fig.update_layout( 186 | width=1000, 187 | height=400, 188 | coloraxis1=dict( 189 | colorscale="RdBu", 190 | colorbar_x=1.007, 191 | colorbar_thickness=23, 192 | colorbar_title="Magnitude in unembedding space", 193 | colorbar_title_side="right", 194 | ), 195 | ) 196 | attn_fig.update_yaxes(autorange="reversed") 197 | attn_fig.update_layout( 198 | xaxis=dict(title="Predicted Year"), 199 | xaxis2=dict(title="Predicted Year"), 200 | yaxis=dict(title="YY", tickmode="array", tickvals=list(range(0, 98, 5)), ticktext=list(range(2, 99, 5))), 201 | ) 202 | attn_fig.update_layout(margin=dict(l=30, r=30, b=30, t=30), title_x=0.5) 203 | attn_fig.show() 204 | attn_fig.write_image("paper-plots/attn-logitlens.pdf") 205 | 206 | 207 | #%% 208 | ### MLP Logit Lens 209 | logit_lens_mlp = [torch.load(f"paper-cache/logit_lens_mlp{i}.pt").cpu() for i in range(8, 12)] 210 | logit_lens_fig = make_subplots( 211 | rows=1, 212 | cols=4, 213 | shared_xaxes=False, 214 | shared_yaxes=True, 215 | subplot_titles=[f"Logit Lens of MLP {i}" for i in [11, 10, 9, 8]], 216 | horizontal_spacing=0.05, 217 | vertical_spacing=0.15, 218 | ) 219 | logit_lens_fig.add_trace( 220 | go.Heatmap(z=logit_lens_mlp[3] - logit_lens_mlp[3][:, 0:1], coloraxis="coloraxis1"), row=1, col=1 221 | ) 222 | logit_lens_fig.add_trace( 223 | go.Heatmap(z=logit_lens_mlp[2] - logit_lens_mlp[2][:, 0:1], coloraxis="coloraxis2"), row=1, col=2 224 | ) 225 | logit_lens_fig.add_trace( 226 | go.Heatmap(z=logit_lens_mlp[1] - logit_lens_mlp[1][:, 0:1], coloraxis="coloraxis3"), row=1, col=3 227 | ) 228 | logit_lens_fig.add_trace( 229 | go.Heatmap(z=logit_lens_mlp[0] - logit_lens_mlp[0][:, 0:1], coloraxis="coloraxis4"), row=1, col=4 230 | ) 231 | 232 | logit_lens_fig.update_layout( 233 | width=2000, 234 | height=400, 235 | coloraxis1=dict( 236 | colorscale="RdBu", 237 | colorbar_x=0.215, 238 | colorbar_thickness=23, 239 | colorbar_title="Magnitude in unembedding space", 240 | colorbar_title_side="right", 241 | cmin=-5, 242 | cmax=4, 243 | ), 244 | coloraxis2=dict( 245 | colorscale="RdBu", 246 | colorbar_x=0.48, 247 | colorbar_thickness=23, 248 | colorbar_title="Magnitude in unembedding space", 249 | colorbar_title_side="right", 250 | cmin=-15, 251 | cmax=15, 252 | ), 253 | coloraxis3=dict( 254 | colorscale="RdBu", 255 | colorbar_x=0.74, 256 | colorbar_thickness=23, 257 | colorbar_title="Magnitude in unembedding space", 258 | colorbar_title_side="right", 259 | ), 260 | coloraxis4=dict( 261 | colorscale="RdBu", 262 | colorbar_x=1.007, 263 | colorbar_thickness=23, 264 | colorbar_title="Magnitude in unembedding space", 265 | colorbar_title_side="right", 266 | cmin=-15, 267 | cmax=15, 268 | ), 269 | ) 270 | 271 | logit_lens_fig.update_layout( 272 | margin=dict(l=30, r=30, b=30, t=30), title_x=0.5, xaxis_title="Predicted Year", yaxis_title="YY" 273 | ) 274 | logit_lens_fig.update_yaxes(autorange="reversed") 275 | logit_lens_fig.update_layout( 276 | xaxis=dict(title="Predicted Year"), 277 | xaxis2=dict(title="Predicted Year"), 278 | xaxis3=dict(title="Predicted Year"), 279 | xaxis4=dict(title="Predicted Year"), 280 | yaxis=dict(title="YY", tickmode="array", tickvals=list(range(0, 98, 5)), ticktext=list(range(2, 99, 5))), 281 | ) 282 | logit_lens_fig.show() 283 | logit_lens_fig.write_image("paper-plots/mlps-logitlens.pdf") 284 | 285 | 286 | #%% 287 | ### Patched prob plot 288 | patched_probs = torch.load("paper-cache/patched_probs.pt").cpu() 289 | fig = show_diffs( 290 | patched_probs, 291 | center_zero=False, 292 | zrange=(0.0, 0.25), 293 | title="Patched GPT-2 Small Probability Heatmap", 294 | zlabel="probability", 295 | color_continuous_scale="Blues", 296 | ) 297 | fig.update_layout(margin=dict(l=0, r=0, b=10, t=30), title_x=0.5, height=500, width=600) 298 | fig.show() 299 | fig.write_image("paper-plots/patched-probability-heatmap.pdf") 300 | #%% 301 | ### Full Patched prob plot 302 | patched_probs = torch.load("paper-cache/full_patched_probs.pt").cpu() 303 | fig = show_diffs( 304 | patched_probs, 305 | center_zero=False, 306 | zrange=(0.0, 0.25), 307 | title="Full-Circuit Patched GPT-2 Small Probability Heatmap", 308 | zlabel="probability", 309 | color_continuous_scale="Blues", 310 | ) 311 | fig.update_layout(margin=dict(l=0, r=0, b=10, t=30), title_x=0.5, height=500, width=600) 312 | fig.show() 313 | fig.write_image("paper-plots/appendix/full-patched-probability-heatmap.pdf") 314 | #%% 315 | attn_v = torch.load("paper-cache/attn_v.pt").cpu() 316 | attn_v_fig = make_ipp_plot(attn_v, "Attention Value", cmin=-0.5, cmax=0.5) 317 | attn_v_fig.show() 318 | attn_v_fig.write_image("paper-plots/appendix/attn_v.pdf") 319 | #%% 320 | ipp_low_mlp = [torch.load(f"paper-cache/results_mlp{i}.pt").cpu() for i in range(0, 4)] 321 | logit_lens_fig = make_subplots( 322 | rows=1, 323 | cols=4, 324 | shared_xaxes=False, 325 | shared_yaxes=True, 326 | subplot_titles=[f"Iterative Path Patching MLP {i}" for i in [3, 2, 1, 0]], 327 | horizontal_spacing=0.01, 328 | vertical_spacing=0.15, 329 | ) 330 | logit_lens_fig.add_trace(go.Heatmap(z=ipp_low_mlp[3], coloraxis="coloraxis1"), row=1, col=1) 331 | logit_lens_fig.add_trace(go.Heatmap(z=ipp_low_mlp[2], coloraxis="coloraxis1"), row=1, col=2) 332 | logit_lens_fig.add_trace(go.Heatmap(z=ipp_low_mlp[1], coloraxis="coloraxis1"), row=1, col=3) 333 | logit_lens_fig.add_trace(go.Heatmap(z=ipp_low_mlp[0], coloraxis="coloraxis1"), row=1, col=4) 334 | 335 | logit_lens_fig.update_layout( 336 | width=2000, 337 | height=400, 338 | coloraxis1=dict( 339 | colorscale="RdBu", 340 | colorbar_x=1.007, 341 | colorbar_thickness=23, 342 | colorbar_title=f"{metric} diff variation", 343 | colorbar_title_side="right", 344 | cmin=-0.5, 345 | cmax=0.5, 346 | ), 347 | ) 348 | 349 | logit_lens_fig.update_layout(margin=dict(l=30, r=30, b=30, t=30), title_x=0.5, xaxis_title="Head", yaxis_title="Layer") 350 | x_labels = [f"h{i}" for i in range(12)] + ["mlp"] 351 | y_labels = list(range(12)) 352 | logit_lens_fig.update_yaxes(autorange="reversed") 353 | logit_lens_fig.update_layout( 354 | xaxis=dict(title="Head", tickmode="array", tickvals=list(range(13)), ticktext=x_labels), 355 | xaxis2=dict(title="Head", tickmode="array", tickvals=list(range(13)), ticktext=x_labels), 356 | xaxis3=dict(title="Head", tickmode="array", tickvals=list(range(13)), ticktext=x_labels), 357 | xaxis4=dict(title="Head", tickmode="array", tickvals=list(range(13)), ticktext=x_labels), 358 | yaxis=dict( 359 | title="Layer", 360 | ), 361 | ) 362 | logit_lens_fig.show() 363 | logit_lens_fig.write_image("paper-plots/appendix/low-mlps-ipp.pdf") 364 | # %% 365 | attention_collated = torch.load('paper-cache/attn_collated.pt')[:, -1, 8] 366 | plt.plot(torch.arange(2, 99), attention_collated) 367 | plt.ylabel("Attention") 368 | plt.xlabel("Input Year (YY)") 369 | plt.xlim(0,100) 370 | plt.title("Attention of a7.h10 (end position) to the YY position") 371 | plt.savefig('paper-plots/appendix/attn_proportion.pdf') 372 | plt.show() 373 | # %% 374 | -------------------------------------------------------------------------------- /sequence_generalization.py: -------------------------------------------------------------------------------- 1 | #%% 2 | from typing import cast 3 | import random 4 | 5 | import torch 6 | import rust_circuit as rc 7 | 8 | from rust_circuit.causal_scrubbing.hypothesis import corr_root_matcher 9 | 10 | from utils import ( 11 | load_gpt2_small_circuit, 12 | HeadOrMlpType, 13 | MLPHeadAndPosSpec, 14 | load_diff_model, 15 | iterative_path_patching_nocorr, 16 | path_patching, 17 | to_device, 18 | load_and_split_gpt2, 19 | show_diffs, 20 | make_extender_factory, 21 | get_valid_years 22 | ) 23 | 24 | #%% 25 | # Loading our base model 26 | DEVICE = "cuda:0" 27 | MODEL_ID = "gelu_12_tied" # aka gpt2 small 28 | _, tokenizer, _ = load_gpt2_small_circuit() 29 | 30 | #%% 31 | """ 32 | Here are the tasks that we discuss: 33 | 0 “The started in the year 17YY and ended in the year 17” 34 | 1 “It was 17YY then. Some years later, it was the year 17” 35 | 2 “1599, 1607, 1633, 1679, 17YY, 17” 36 | 37 | 3 “1799, 1753, 1733, 1701, 16YY, 16” 38 | 4 Exact-answer tasks, e.g. “1599, 1607, 1633, 1679, 17YY, 17” 39 | 5 “17YY is smaller than 17” 40 | 41 | 6 The ended in the year 17YY and started in the year 17” 42 | 7 “The lasted from the year 7YY BC to the year 7” 43 | 8 "The price of that ranges from $ 17YY to $ 17" 44 | 45 | 9 "XXY1, XXY2, XXY3, XXY4, XXY5, XX", where Y1,...,Y5 are randomly sampled 46 | """ 47 | 48 | # Creating our dataset 49 | years = torch.arange(1702, 1799) 50 | last_two_digits = years % 100 51 | 52 | with open("cache/potential_nouns.txt", "r") as f: 53 | noun_list = [line.strip() for line in f] 54 | nouns = random.choices(noun_list, k=len(years)) 55 | 56 | for option in range(10): 57 | gt = True 58 | if option == 0: 59 | sentences = [ 60 | f"<|endoftext|> The {noun} started in the year 17{y:02d} and ended in the year 17" 61 | for noun, y in zip(nouns, last_two_digits) 62 | ] 63 | sentences_01 = [ 64 | f"<|endoftext|> The {noun} started in the year 1701 and ended in the year 17" 65 | for noun, _ in zip(nouns, last_two_digits) 66 | ] 67 | elif option == 1: 68 | sentences = [ 69 | f"<|endoftext|> The {noun} happened in 17{y:02d}. Some years later, it is now the year 17" 70 | for noun, y in zip(nouns, last_two_digits) 71 | ] 72 | sentences_01 = [ 73 | f"<|endoftext|> The {noun} happened in 1701. Some years later and it is now the year 17" 74 | for noun, _ in zip(nouns, last_two_digits) 75 | ] 76 | elif option == 2: 77 | sentences = [f"<|endoftext|> 1599, 1607, 1633, 1679, 17{y:02d}, 17" for y in last_two_digits] 78 | sentences_01 = [f"<|endoftext|> 1599, 1607, 1633, 1679, 1701, 17" for _ in last_two_digits] 79 | elif option == 3: 80 | sentences = [f"<|endoftext|> 1799, 1753, 1733, 1701, 16{y:02d}, 16" for y in last_two_digits] 81 | sentences_01 = ["<|endoftext|> 1799, 1753, 1733, 1701, 1699, 16" for _ in last_two_digits] 82 | gt = False 83 | elif option == 4: 84 | sentences = [] 85 | corrects = [] 86 | for y in years: 87 | i = 2 88 | while (y % 100) % i == 0: 89 | i += 1 90 | sentences.append(f"<|endoftext|> {y-4*i:04d}, {y-3*i:04d}, {y-2*i:04d}, {y-1*i:04d}, {y:04d}, 17") 91 | corrects.append(y + i) 92 | sentences_01 = [f"<|endoftext|> 1693, 1695, 1697, 1699, 1701, 17" for _ in years] 93 | elif option == 5: 94 | sentences = [f"<|endoftext|> 17{y:02d} is smaller than 17" for y in last_two_digits] 95 | sentences_01 = [f"<|endoftext|> 1701 is smaller than 17" for _ in last_two_digits] 96 | elif option == 6: 97 | sentences = [ 98 | f"<|endoftext|> The {noun} ended in the year 17{y:02d} and started in the year 17" 99 | for noun, y in zip(nouns, last_two_digits) 100 | ] 101 | sentences_01 = [ 102 | f"<|endoftext|> The {noun} ended in the year 1799 and started in the year 17" 103 | for noun, _ in zip(nouns, last_two_digits) 104 | ] 105 | gt = False 106 | elif option == 7: 107 | sentences = [ 108 | f"<|endoftext|> The {noun} lasted from the year 7{y:02d} BC to the year 7" 109 | for noun, y in zip(nouns, last_two_digits) 110 | ] 111 | sentences_01 = [ 112 | f"<|endoftext|> The {noun} lasted from the year 799 BC to the year 7" 113 | for noun, _ in zip(nouns, last_two_digits) 114 | ] 115 | for i in [0, 18, 35, 45, 48, 58, 66, 68, 75, 78]: 116 | sentences[i] = sentences[i + 1] 117 | gt = False 118 | elif option == 8: 119 | items = [ 120 | "gem", 121 | "necklace", 122 | "watch", 123 | "ring", 124 | "suitcase", 125 | "scarf", 126 | "suit", 127 | "shirt", 128 | "sweater", 129 | "dress", 130 | "fridge", 131 | "TV", 132 | "bed", 133 | "bike", 134 | "lamp", 135 | "table", 136 | "chair", 137 | "painting", 138 | "sculpture", 139 | "plant", 140 | ] 141 | sentences = [ 142 | f"<|endoftext|> The price of that {item} ranges from $ 17{y:02} to $ 17" 143 | for y, item in zip(last_two_digits, items * 5) 144 | ] 145 | sentences_01 = [ 146 | f"<|endoftext|> The price of that {item} ranges from $ 1701 to $ 17" 147 | for y, item in zip(last_two_digits, items * 5) 148 | ] 149 | elif option == 9: 150 | years = [] 151 | n = len(last_two_digits) 152 | centuries = torch.arange(10,19) * 100 153 | years_XX00 = centuries[torch.randint(len(centuries), (n,))] 154 | years_XX = years_XX00 // 100 155 | years_to_sample_from = get_valid_years(tokenizer, 1000, 1900) 156 | for XX00 in years_XX00: 157 | sample_space = years_to_sample_from[(years_to_sample_from >= XX00) & (years_to_sample_from < XX00+100)] 158 | years.append(sample_space[torch.randint(sample_space.size(0), (5,))]) 159 | years = torch.stack(years) 160 | years_YY = years % 100 161 | last_two_digits = years_YY[:, -1] 162 | sentences = [f'{str(y.tolist())[1:-1]}, {XX}' for y, XX in zip(years, years_XX)] 163 | sentences_01 = [f'{str(y.tolist())[1:-3]}01, {XX}' for y, XX in zip(years, years_XX)] 164 | else: 165 | raise ValueError(f"Bad option given (should be 0 - 10): {option}") 166 | 167 | toks = [tokenizer(sentence, return_tensors="pt")["input_ids"].squeeze() for sentence in sentences] 168 | toks = torch.stack(toks).cuda() 169 | toks_01 = torch.stack( 170 | [tokenizer(sentence, return_tensors="pt")["input_ids"].squeeze() for sentence in sentences_01] 171 | ).cuda() 172 | 173 | MAX_LEN = toks.size(-1) 174 | END_POS = MAX_LEN - 1 175 | 176 | masks = [] 177 | for year in last_two_digits: 178 | if gt: 179 | mask = torch.arange(100) > year 180 | else: 181 | mask = torch.arange(100) < year 182 | masks.append(mask) 183 | 184 | masks = torch.stack(masks) 185 | 186 | # Splitting our model to make it pretty 187 | logit_diff = False 188 | metric = "prob" 189 | circuit = load_and_split_gpt2(MAX_LEN) 190 | year_indices = torch.load("cache/logit_indices.pt") 191 | ld_circuit, group = load_diff_model(circuit, year_indices, masks, logit_diff=logit_diff, device=DEVICE) 192 | 193 | def se(c): 194 | """Short function for Sample and Evaluate along the global variable `group`""" 195 | transform = rc.Sampler(rc.RunDiscreteVarAllSpec([group])) 196 | return transform.sample(c).evaluate() 197 | 198 | # Let's make a copy of the circuit that actually has inputs! 199 | c = ld_circuit.update( 200 | "tokens", 201 | lambda _: rc.DiscreteVar(to_device(rc.Array(toks, name="tokens"), DEVICE), probs_and_group=group), 202 | ) 203 | baseline_mean = se(c).mean() 204 | 205 | probs = torch.softmax(se(c.get_unique("logits")), dim=-1)[:, -1, year_indices] 206 | torch.save(probs, f"paper-cache/generalization/probs_{option}.pt") 207 | 208 | # We need to make an extender factory, and then some matcher extenders to iteratively path patch with 209 | extender_factory = make_extender_factory(MAX_LEN) 210 | end_pos_matcher_extenders = [ 211 | extender_factory(MLPHeadAndPosSpec(l, cast(HeadOrMlpType, h), END_POS), qkv=None) 212 | for l in range(12) 213 | for h in list(range(12)) + ["mlp"] 214 | ] 215 | 216 | def iterative_path_patch(matchers_to_extend, matcher_extenders, patch_data): 217 | """Calls iterative path patching, keeping the baseline / patch data, group, input_name, and output_name constant""" 218 | return ( 219 | iterative_path_patching_nocorr( 220 | circuit=ld_circuit, 221 | matchers_to_extend=matchers_to_extend, 222 | baseline_data=toks, 223 | patch_data=patch_data, 224 | group=group, 225 | matcher_extenders=matcher_extenders, 226 | input_name="tokens", 227 | output_shape=(12, 13, -1), 228 | ).mean(-1) 229 | ) - baseline_mean 230 | 231 | # Let's see what nodes are important, starting from the root, and looking at all MLPs / attention heads 232 | alt_tok_name = "01" 233 | results = iterative_path_patch([corr_root_matcher], end_pos_matcher_extenders, toks_01) 234 | torch.save(results, f"paper-cache/generalization/ipp_{option}.pt") 235 | 236 | if option in {0,2,8,9}: 237 | # Let's see what nodes are important, starting from root->m11, and looking at all MLPs / attention heads 238 | alt_tok_name = "01" 239 | m11_extender = extender_factory(MLPHeadAndPosSpec(11, cast(HeadOrMlpType, "mlp"), END_POS), qkv=None) 240 | m11_matcher = m11_extender(corr_root_matcher) 241 | 242 | 243 | # Let's see what nodes are important, starting from root->m11->m10, and looking at all MLPs / attention heads 244 | alt_tok_name = "01" 245 | m10_extender = extender_factory(MLPHeadAndPosSpec(10, cast(HeadOrMlpType, "mlp"), END_POS), qkv=None) 246 | m10_matcher = m10_extender(corr_root_matcher | m11_matcher) 247 | 248 | # Let's see what nodes are important, starting from root->m11->m10->m9, and looking at all MLPs / attention heads 249 | alt_tok_name = "01" 250 | m9_extender = extender_factory(MLPHeadAndPosSpec(9, cast(HeadOrMlpType, "mlp"), END_POS), qkv=None) 251 | m9_matcher = m9_extender(corr_root_matcher | m11_matcher | m10_matcher) 252 | 253 | 254 | # Let's see what nodes are important, starting from root->m11->m10->m9->m8, and looking at all MLPs / attention heads 255 | alt_tok_name = "01" 256 | m8_extender = extender_factory(MLPHeadAndPosSpec(8, cast(HeadOrMlpType, "mlp"), END_POS), qkv=None) 257 | m8_matcher = m8_extender(corr_root_matcher | m11_matcher | m10_matcher | m9_matcher) 258 | results = iterative_path_patch([m8_matcher], end_pos_matcher_extenders, toks_01) 259 | torch.save(results, f"paper-cache/generalization/m8_{option}.pt") 260 | 261 | if option == 2: 262 | # Let's see what nodes are important, starting from root->m11->m10->m9->m8, and looking at all MLPs / attention heads 263 | alt_tok_name = "01" 264 | m7_extender = extender_factory(MLPHeadAndPosSpec(7, cast(HeadOrMlpType, "mlp"), END_POS), qkv=None) 265 | m7_matcher = m7_extender(corr_root_matcher | m11_matcher | m10_matcher | m9_matcher | m8_matcher) 266 | results = iterative_path_patch([m7_matcher], end_pos_matcher_extenders, toks_01) 267 | torch.save(results, f"paper-cache/generalization/m7_{option}.pt") 268 | 269 | extender_factory = make_extender_factory(MAX_LEN) 270 | end_pos_matcher_extenders = [ 271 | extender_factory(MLPHeadAndPosSpec(l, cast(HeadOrMlpType, h), END_POS), qkv=None) 272 | for l in range(12) 273 | for h in list(range(12)) + ["mlp"] 274 | ] 275 | 276 | extra_mlps = [7] if option == 2 else [] 277 | extra_heads = [(7, 11), (6, 1)] if option == 2 else [] 278 | mlp_set_extender = extender_factory( 279 | {MLPHeadAndPosSpec(i, cast(HeadOrMlpType, "mlp"), END_POS) for i in [8, 9, 10, 11] + extra_mlps}, qkv=None 280 | ) 281 | attention_set_extenders = extender_factory( 282 | { 283 | MLPHeadAndPosSpec(layer, head, END_POS) 284 | for layer, head in [(9, 1), (8, 11), (7, 10), (6, 9), (5, 5), (8, 8), (5, 1)] + extra_heads 285 | } 286 | ) 287 | running = corr_root_matcher 288 | ms = attention_set_extenders(corr_root_matcher) 289 | for i in range(4): 290 | running = mlp_set_extender(running) 291 | ms = ms | attention_set_extenders(running) 292 | 293 | patched_circuit = path_patching( 294 | ld_circuit, 295 | toks_01, # unpatched nodes get bad data 296 | toks, # patched ndoes get good data 297 | ms, 298 | group, 299 | "tokens", 300 | ) 301 | 302 | patched_results = se(patched_circuit) 303 | patched_mean = patched_results.mean() 304 | print(patched_mean, baseline_mean, patched_mean / baseline_mean) 305 | probs = torch.softmax(se(patched_circuit.get_unique("logits")), dim=-1)[:, -1, year_indices] 306 | show_diffs(probs, center_zero=False, title="Probability heatmap", color_continuous_scale="Blues").show() 307 | 308 | probs = torch.softmax(se(patched_circuit.get_unique("logits")), dim=-1)[:, -1, year_indices] 309 | show_diffs(probs, center_zero=False, title="Probability heatmap", color_continuous_scale="Blues").show() 310 | -------------------------------------------------------------------------------- /circuit_discovery.py: -------------------------------------------------------------------------------- 1 | #%% 2 | from typing import cast 3 | from pathlib import Path 4 | 5 | import torch 6 | import rust_circuit as rc 7 | import matplotlib.pyplot as plt 8 | 9 | from rust_circuit.ui import cui 10 | from rust_circuit.ui.very_named_tensor import VeryNamedTensor 11 | from rust_circuit.causal_scrubbing.hypothesis import corr_root_matcher 12 | 13 | from dataset import YearDataset 14 | from utils import ( 15 | collate, 16 | HeadOrMlpType, 17 | MLPHeadAndPosSpec, 18 | load_gpt2_small_circuit, 19 | load_diff_model, 20 | iterative_path_patching_nocorr, 21 | path_patching, 22 | to_device, 23 | logit_lens_ln, 24 | load_and_split_gpt2, 25 | show_mtx, 26 | get_attention_pattern, 27 | await_without_await, 28 | show_diffs, 29 | mean_logit_diff, 30 | get_valid_years, 31 | make_all_nodes_names, 32 | make_extender_factory, 33 | cutoff_sharpness, 34 | make_scrubbed_printer, 35 | get_valid_years, 36 | ) 37 | 38 | #%% 39 | # Loading our base model 40 | DEVICE = "cuda:0" 41 | MODEL_ID = "gelu_12_tied" # aka gpt2 small 42 | _, tokenizer, _ = load_gpt2_small_circuit() 43 | 44 | #%% 45 | # Creating our dataset 46 | years_to_sample_from = get_valid_years(tokenizer, 1000, 1900) 47 | N = 490 48 | ds = YearDataset(years_to_sample_from, N, Path("cache/potential_nouns.txt"), tokenizer, balanced=True, device=DEVICE, eos=True) 49 | 50 | MAX_LEN = ds.good_toks.size(-1) 51 | END_POS = MAX_LEN - 1 52 | XX1_POS = ds.good_prompt.index("XX1") 53 | YY_POS = ds.good_prompt.index("YY") 54 | 55 | #%% 56 | # Splitting our model to make it pretty 57 | metric = "prob" 58 | circuit = load_and_split_gpt2(MAX_LEN) 59 | year_indices = torch.load("cache/logit_indices.pt") 60 | ld_circuit, group = load_diff_model(circuit, year_indices, ds.good_mask, device=DEVICE) 61 | 62 | #%% 63 | def se(c): 64 | """Short function for Sample and Evaluate along the global variable `group`""" 65 | transform = rc.Sampler(rc.RunDiscreteVarAllSpec([group])) 66 | return transform.sample(c).evaluate() 67 | 68 | 69 | def sec(c): 70 | """Short function for Sample, Evaluate, and collate along the global variable `group`""" 71 | transform = rc.Sampler(rc.RunDiscreteVarAllSpec([group])) 72 | return collate(transform.sample(c).evaluate(), ds.years_YY) 73 | 74 | 75 | # Let's make a copy of the circuit that actually has inputs! 76 | c = ld_circuit.update( 77 | "tokens", 78 | lambda _: rc.DiscreteVar(to_device(rc.Array(ds.good_toks, name="tokens"), DEVICE), probs_and_group=group), 79 | ) 80 | baseline_mean = se(c).mean() 81 | 82 | #%% 83 | # Let's visualize normal model behavior! 84 | probs = torch.softmax(sec(c.get_unique("logits")), dim=-1)[:, -1, year_indices] 85 | # torch.save(probs, "paper-cache/probs.pt") 86 | fig = show_diffs( 87 | probs, 88 | center_zero=False, 89 | zrange=(0.0, 0.25), 90 | title="GPT-2 Small Probability Heatmap", 91 | zlabel="probability", 92 | color_continuous_scale="amp", 93 | ) 94 | fig.show() 95 | #%% 96 | # let's look at an individual one 97 | i = 39 98 | plt.plot(probs[i].cpu()) 99 | plt.title(f"GPT-2 Probabilities when YY={i + 2}") 100 | plt.xlabel(f"Predicted Year") 101 | plt.ylabel(f"probability") 102 | plt.show() 103 | 104 | #%% 105 | # We need to make an extender factory, and then some matcher extenders to iteratively path patch with 106 | extender_factory = make_extender_factory(MAX_LEN) 107 | end_pos_matcher_extenders = [ 108 | extender_factory(MLPHeadAndPosSpec(l, cast(HeadOrMlpType, h), END_POS), qkv=None) 109 | for l in range(12) 110 | for h in list(range(12)) + ["mlp"] 111 | ] 112 | 113 | 114 | def iterative_path_patch(matchers_to_extend, matcher_extenders, patch_data): 115 | """Calls iterative path patching, keeping the baseline / patch data, group, input_name, and output_name constant""" 116 | return ( 117 | iterative_path_patching_nocorr( 118 | circuit=ld_circuit, 119 | matchers_to_extend=matchers_to_extend, 120 | baseline_data=ds.good_toks, 121 | patch_data=patch_data, 122 | group=group, 123 | matcher_extenders=matcher_extenders, 124 | input_name="tokens", 125 | output_shape=(12, 13, -1), 126 | ).mean(-1) 127 | ) - baseline_mean 128 | 129 | 130 | # %% 131 | # Let's see what nodes are important, starting from the root, and looking at all MLPs / attention heads 132 | 133 | results = iterative_path_patch([corr_root_matcher], end_pos_matcher_extenders, ds.bad_toks) 134 | torch.save(results, "paper-cache/ipp_logits.pt") 135 | show_mtx( 136 | results.cpu(), 137 | title=f"logits", 138 | color_map_label=f"{metric} diff variation", 139 | ) 140 | 141 | # %% 142 | # Let's see what nodes are important, starting from root->m11, and looking at all MLPs / attention heads 143 | m11_extender = extender_factory(MLPHeadAndPosSpec(11, cast(HeadOrMlpType, "mlp"), END_POS), qkv=None) 144 | m11_matcher = m11_extender(corr_root_matcher) 145 | results = iterative_path_patch([m11_matcher], end_pos_matcher_extenders, ds.bad_toks) 146 | torch.save(results, "paper-cache/ipp_mlp11.pt") 147 | show_mtx( 148 | results.cpu(), 149 | title=f"m11", # f"{metric} diff variation m11 (patch data: {alt_tok_name}-dataset)", 150 | color_map_label=f"{metric} diff variation", 151 | ) 152 | 153 | #%% 154 | # Let's see what nodes are important, starting from root->m11->m10, and looking at all MLPs / attention heads 155 | m10_extender = extender_factory(MLPHeadAndPosSpec(10, cast(HeadOrMlpType, "mlp"), END_POS), qkv=None) 156 | m10_matcher = m10_extender(corr_root_matcher | m11_matcher) 157 | results = iterative_path_patch([m10_matcher], end_pos_matcher_extenders, ds.bad_toks) 158 | torch.save(results, "paper-cache/ipp_mlp10.pt") 159 | show_mtx( 160 | results.cpu(), 161 | title=f"m10", # f"{metric} diff variation m10 (patch data: {alt_tok_name}-dataset)", 162 | color_map_label=f"{metric} diff variation", 163 | ) 164 | # %% 165 | # Let's see what nodes are important, starting from root->m11->m10->m9, and looking at all MLPs / attention heads 166 | m9_extender = extender_factory(MLPHeadAndPosSpec(9, cast(HeadOrMlpType, "mlp"), END_POS), qkv=None) 167 | m9_matcher = m9_extender(corr_root_matcher | m11_matcher | m10_matcher) 168 | results = iterative_path_patch([m9_matcher], end_pos_matcher_extenders, ds.bad_toks) 169 | torch.save(results, "paper-cache/ipp_mlp9.pt") 170 | show_mtx( 171 | results.cpu(), 172 | title=f"m9", # f"{metric} diff variation m9 (patch data: {alt_tok_name}-dataset)", 173 | color_map_label=f"{metric} diff variation", 174 | ) 175 | 176 | #%% 177 | # Let's see what nodes are important, starting from root->m11->m10->m9->m8, and looking at all MLPs / attention heads 178 | m8_extender = extender_factory(MLPHeadAndPosSpec(8, cast(HeadOrMlpType, "mlp"), END_POS), qkv=None) 179 | m8_matcher = m8_extender(corr_root_matcher | m11_matcher | m10_matcher | m9_matcher) 180 | results = iterative_path_patch([m8_matcher], end_pos_matcher_extenders, ds.bad_toks) 181 | torch.save(results, "paper-cache/ipp_mlp8.pt") 182 | show_mtx( 183 | results.cpu(), 184 | title=f"m8", # f"{metric} diff variation m8 (patch data: {alt_tok_name}-dataset)", 185 | color_map_label=f"{metric} diff variation", 186 | ) 187 | 188 | # %% 189 | # What other heads could be important? Let's look at the attention patterns of all heads to find out 190 | # First get the attention patterns for 20 examples 191 | heads = [(i, j) for i in range(12) for j in range(12)] 192 | n_examples = 20 193 | attn_patterns = get_attention_pattern(to_device(circuit, DEVICE), heads, ds.good_toks[:n_examples]).cpu() 194 | 195 | # Then visualize them (maybe take the mean over sentences) 196 | await_without_await(lambda: cui.init(port=6781)) 197 | attn_pattern_vnt = VeryNamedTensor( 198 | attn_patterns, 199 | dim_names="head sentence queries keys".split(), 200 | dim_types="example example axis axis".split(), 201 | dim_idx_names=[ 202 | heads, 203 | [f"seq {i}" for i in range(n_examples)], 204 | ds.good_prompt, 205 | ds.good_prompt, 206 | ], 207 | title="Attention patterns", 208 | ) 209 | await_without_await(lambda: cui.show_tensors(attn_pattern_vnt)) 210 | 211 | #%% 212 | attn_patterns_7_10 = get_attention_pattern(to_device(circuit, DEVICE), [(7, 10)], ds.good_toks).cpu()[0] 213 | attention_collated = collate(attn_patterns_7_10, ds.years_YY) 214 | torch.save(attention_collated, "paper-cache/attn_collated.pt") 215 | mean_attn_patterns_7_10 = attn_patterns_7_10.mean(0) 216 | attn_patterns_8_11 = get_attention_pattern(to_device(circuit, DEVICE), [(8, 11)], ds.good_toks).cpu()[0] 217 | mean_attn_patterns_8_11 = attn_patterns_8_11.mean(0) 218 | torch.save(attn_patterns_7_10, "paper-cache/attn_patterns_a7.h10.pt") 219 | torch.save(attn_patterns_8_11, "paper-cache/attn_patterns_a8.h1.pt") 220 | 221 | #%% 222 | # So what do these heads do? We can examine this question with logit lens (or DPP) 223 | module = "a7.h10" 224 | logits = sec(logit_lens_ln(c, module, device=DEVICE)) 225 | torch.save(logits, "paper-cache/logit_lens_a7.h10.pt") 226 | show_diffs(logits, title=f"Logit lens of {module}") 227 | #%% 228 | module = "a8.h11" 229 | logits = sec(logit_lens_ln(c, module, device=DEVICE)) 230 | torch.save(logits, "paper-cache/logit_lens_a8.h11.pt") 231 | show_diffs(logits, title=f"Logit lens of {module}") 232 | #%% 233 | module = "a9.h1" 234 | logits = sec(logit_lens_ln(c, module, device=DEVICE)) 235 | show_diffs(logits, title=f"Logit lens of {module}") 236 | 237 | #%% 238 | module = "a11.h0" 239 | logits = sec(logit_lens_ln(c, module, device=DEVICE)) 240 | show_diffs(logits - logits[:, 0:1], title=f"Logit lens of {module}") 241 | 242 | 243 | #%% 244 | # So what do these heads do? We can examine this question with logit lens (or DPP) 245 | module = "m8" 246 | logits = sec(logit_lens_ln(c, module, device=DEVICE)) 247 | torch.save(logits, "paper-cache/logit_lens_mlp8.pt") 248 | show_diffs(logits - logits[:, 0:1], title=f"Logit lens of {module}").show() 249 | print(module, "logit mean diff", mean_logit_diff(logits), "cutoff sharpness", cutoff_sharpness(logits).mean()) 250 | #%% 251 | module = "m9" 252 | logits = sec(logit_lens_ln(c, module, device=DEVICE)) 253 | torch.save(logits, "paper-cache/logit_lens_mlp9.pt") 254 | show_diffs(logits - logits[:, 0:1], title=f"Logit lens of {module}").show() 255 | print(module, "logit mean diff", mean_logit_diff(logits), "cutoff sharpness", cutoff_sharpness(logits).mean()) 256 | #%% 257 | module = "m10" 258 | logits = sec(logit_lens_ln(c, module, device=DEVICE)) 259 | torch.save(logits, "paper-cache/logit_lens_mlp10.pt") 260 | show_diffs(logits - logits[:, 0:1], title=f"Logit lens of {module}").show() 261 | print(module, "logit mean diff", mean_logit_diff(logits), "cutoff sharpness", cutoff_sharpness(logits).mean()) 262 | #%% 263 | module = "m11" 264 | logits = sec(logit_lens_ln(c, module, device=DEVICE)) 265 | torch.save(logits, "paper-cache/logit_lens_mlp11.pt") 266 | sd = show_diffs(logits - logits[:, 0:1], title=f"Logit lens of {module}").show() 267 | print(module, "logit mean diff", mean_logit_diff(logits), "cutoff sharpness", cutoff_sharpness(logits).mean()) 268 | 269 | #%% 270 | m11_extender = extender_factory(MLPHeadAndPosSpec(11, cast(HeadOrMlpType, "mlp"), END_POS), qkv=None) 271 | m11_matcher = m11_extender(corr_root_matcher) 272 | m10_extender = extender_factory(MLPHeadAndPosSpec(10, cast(HeadOrMlpType, "mlp"), END_POS), qkv=None) 273 | m10_matcher = m10_extender(corr_root_matcher | m11_matcher) 274 | m9_extender = extender_factory(MLPHeadAndPosSpec(9, cast(HeadOrMlpType, "mlp"), END_POS), qkv=None) 275 | m9_matcher = m9_extender(corr_root_matcher | m11_matcher | m10_matcher) 276 | m8_extender = extender_factory(MLPHeadAndPosSpec(8, cast(HeadOrMlpType, "mlp"), END_POS), qkv=None) 277 | m8_matcher = m8_extender(corr_root_matcher | m11_matcher | m10_matcher | m9_matcher) 278 | 279 | 280 | mlp_set_extender = extender_factory( 281 | {MLPHeadAndPosSpec(i, cast(HeadOrMlpType, "mlp"), END_POS) for i in range(8, 12)}, qkv=None 282 | ) 283 | attention_set_extenders = extender_factory( 284 | { 285 | MLPHeadAndPosSpec(layer, head, END_POS) 286 | for layer, head in [(9, 1), (8, 11), (7, 10), (6, 9), (5, 5), (8, 8), (5, 1)] 287 | } 288 | ) 289 | running = corr_root_matcher 290 | ms = attention_set_extenders(corr_root_matcher) 291 | for i in range(4): 292 | running = mlp_set_extender(running) 293 | ms = ms | attention_set_extenders(running) 294 | #%% 295 | patched_circuit = path_patching( 296 | ld_circuit, 297 | ds.bad_toks, # unpatched nodes get bad data 298 | ds.good_toks, # patched ndoes get good data 299 | ms, 300 | group, 301 | "tokens", 302 | ) 303 | 304 | patched_results = se(patched_circuit).mean() 305 | ms_patched_results = patched_results 306 | print(patched_results, baseline_mean, patched_results / baseline_mean) 307 | 308 | #%% 309 | probs = torch.softmax(sec(patched_circuit.get_unique("logits")), dim=-1)[:, -1, year_indices] 310 | fig = show_diffs( 311 | probs, 312 | center_zero=False, 313 | zrange=(0.0, 0.25), 314 | title="GPT-2 Small Probability Heatmap (Patched)", 315 | zlabel="probability", 316 | color_continuous_scale="amp", 317 | ) 318 | fig.show() 319 | 320 | # %% 321 | # Now, let's find the rest of the circuit! 322 | # what's important to the values of our attention heads 323 | yy_pos_q_matcher_extenders = [ 324 | extender_factory(MLPHeadAndPosSpec(l, cast(HeadOrMlpType, h), YY_POS), qkv="q") 325 | for l in range(12) 326 | for h in list(range(12)) + ["mlp"] 327 | ] 328 | 329 | results = iterative_path_patch([ms], yy_pos_q_matcher_extenders, ds.bad_toks) 330 | torch.save(results, "paper-cache/attn_q.pt") 331 | #%% 332 | show_mtx( 333 | results.cpu(), 334 | title=f"nodes important to attention heads' query vectors", # f"{metric} diff variation m8 (patch data: {alt_tok_name}-dataset)", 335 | color_map_label=f"{metric} diff variation", 336 | ).show() 337 | 338 | # %% 339 | # Now, let's find the rest of the circuit! 340 | # what's important to the values of our attention heads 341 | yy_pos_k_matcher_extenders = [ 342 | extender_factory(MLPHeadAndPosSpec(l, cast(HeadOrMlpType, h), YY_POS), qkv="k") 343 | for l in range(12) 344 | for h in list(range(12)) + ["mlp"] 345 | ] 346 | 347 | results = iterative_path_patch([ms], yy_pos_k_matcher_extenders, ds.bad_toks) 348 | torch.save(results, "paper-cache/attn_k.pt") 349 | #%% 350 | show_mtx( 351 | results.cpu(), 352 | title=f"nodes important to attention heads' key vectors", # f"{metric} diff variation m8 (patch data: {alt_tok_name}-dataset)", 353 | color_map_label=f"{metric} diff variation", 354 | ).show() 355 | # %% 356 | # Now, let's find the rest of the circuit! 357 | # what's important to the values of our attention heads 358 | yy_pos_v_matcher_extenders = [ 359 | extender_factory(MLPHeadAndPosSpec(l, cast(HeadOrMlpType, h), YY_POS), qkv="v") 360 | for l in range(12) 361 | for h in list(range(12)) + ["mlp"] 362 | ] 363 | 364 | results = iterative_path_patch([ms], yy_pos_v_matcher_extenders, ds.bad_toks) 365 | torch.save(results, "paper-cache/attn_v.pt") 366 | #%% 367 | show_mtx( 368 | results.cpu(), 369 | title=f"nodes important to attention heads' value vectors", # f"{metric} diff variation m8 (patch data: {alt_tok_name}-dataset)", 370 | color_map_label=f"{metric} diff variation", 371 | ).show() 372 | 373 | # %% 374 | yy_pos_matcher_extenders = [ 375 | extender_factory(MLPHeadAndPosSpec(l, cast(HeadOrMlpType, h), YY_POS), qkv=None) 376 | for l in range(12) 377 | for h in list(range(12)) + ["mlp"] 378 | ] 379 | 380 | for i in range(4): 381 | mlp_matcher = extender_factory(MLPHeadAndPosSpec(i, "mlp", YY_POS), qkv="v")(ms) 382 | mlp_results = iterative_path_patch([mlp_matcher], yy_pos_matcher_extenders, ds.bad_toks) 383 | torch.save(results, f"paper-cache/results_mlp{i}.pt") 384 | show_mtx( 385 | mlp_results.cpu(), 386 | title=f"nodes important to m{i}", # f"{metric} diff variation m8 (patch data: {alt_tok_name}-dataset)", 387 | color_map_label=f"{metric} diff variation", 388 | ).show() 389 | #%% 390 | def embed_extender(m: rc.IterativeMatcher): 391 | return m.chain(rc.new_traversal(start_depth=0, end_depth=1)).chain( 392 | rc.new_traversal(start_depth=1, end_depth=2).chain( 393 | rc.restrict( 394 | rc.Matcher("embeds"), 395 | term_early_at=rc.Matcher(make_all_nodes_names(MAX_LEN)), 396 | term_if_matches=True, 397 | ) 398 | ) 399 | ) 400 | 401 | lower_extenders = extender_factory( 402 | { 403 | MLPHeadAndPosSpec(layer, head, YY_POS) 404 | for layer, head in [(3, "mlp"), (2, "mlp"), (1, "mlp"), (0, "mlp"), (0, 5), (0, 3), (0, 1)] 405 | } 406 | ) 407 | lower_extenders2 = extender_factory( 408 | { 409 | MLPHeadAndPosSpec(layer, head, YY_POS) 410 | for layer, head in [(3, "mlp"), (2, "mlp"), (1, "mlp"), (0, "mlp"), (0, 1)] 411 | } 412 | ) 413 | running = lower_extenders(ms) 414 | lms = embed_extender(running) #| embed_extender(ms) 415 | for i in range(4): 416 | running = lower_extenders2(running) 417 | lms = lms | embed_extender(running) # running.chain('embeds') 418 | lms = lms | ms.chain(rc.restrict({"a.q"}, term_if_matches=True, end_depth=8)) 419 | #%% 420 | patched_circuit = path_patching( 421 | ld_circuit, 422 | ds.bad_toks, # unpatched nodes get bad data 423 | ds.good_toks, # patched nodes get good data 424 | lms,#whole_circuit_matchers, 425 | group, 426 | "tokens", 427 | ) 428 | printer = make_scrubbed_printer(*patched_circuit.get("tokens")) 429 | 430 | whole_circuit_results = se(patched_circuit) 431 | print(whole_circuit_results.mean()) 432 | print(whole_circuit_results.mean() / se(c).mean()) 433 | print(whole_circuit_results.mean() / ms_patched_results) 434 | # %% 435 | probs = torch.softmax(sec(patched_circuit.get_unique("logits")), dim=-1)[:, -1, year_indices] 436 | fig = show_diffs( 437 | probs, 438 | center_zero=False, 439 | zrange=(0.0, 0.25), 440 | title="GPT-2 Small Probability Heatmap", 441 | zlabel="probability", 442 | color_continuous_scale="amp", 443 | ) 444 | fig.show() 445 | # %% 446 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import time 2 | from typing import Optional, Callable, List, Tuple, cast, Any, Set, Union, Literal, NamedTuple 3 | 4 | import torch 5 | import plotly.express as px 6 | import rust_circuit as rc 7 | import numpy as np 8 | from transformers import GPT2TokenizerFast 9 | 10 | 11 | from rust_circuit.py_utils import Slicer as S 12 | from rust_circuit.causal_scrubbing.hypothesis import corr_root_matcher 13 | from rust_circuit.algebric_rewrite import split_to_concat 14 | from rust_circuit.model_rewrites import To, configure_transformer 15 | 16 | year_indices = torch.load("cache/logit_indices.pt") 17 | 18 | HeadOrMlpType = Union[int, Literal["mlp"]] 19 | AttnSuffixForGpt = Union[Literal[""], Literal[".out"]] 20 | 21 | def make_arr( 22 | tokens: torch.Tensor, 23 | name: str, 24 | device_dtype: rc.TorchDeviceDtype = rc.TorchDeviceDtype("cuda:0", "float32"), 25 | ) -> rc.Array: 26 | return rc.cast_circuit(rc.Array(tokens, name=name), device_dtype.op()).cast_array() 27 | 28 | class MLPHeadAndPosSpec(NamedTuple): 29 | layer: int 30 | head_or_mlp: HeadOrMlpType 31 | pos: int 32 | 33 | def to_name(self, attn_suffix_for_bias: str) -> str: 34 | if self.head_or_mlp == "mlp": 35 | return f"m{self.layer}_t{self.pos}" 36 | else: 37 | return f"a{self.layer}{attn_suffix_for_bias}_h{self.head_or_mlp}_t{self.pos}" 38 | 39 | def load_model_path(path: str): 40 | """Load a .circ file. 41 | """ 42 | from rust_circuit.module_library import load_transformer_model_string 43 | 44 | with open(path) as f: 45 | return load_transformer_model_string(f.read()) 46 | 47 | 48 | def load_gpt2_small_circuit(): 49 | return load_model_path("cache/gelu_12_tied.circ") 50 | 51 | 52 | def get_valid_years( 53 | tokenizer: GPT2TokenizerFast, 54 | start: int = 1000, 55 | end: int = 2150, 56 | ): 57 | """Get valid years (_abcd) between [start, end) that are tokenized into 58 | [_ab, cd] by the input tokenizer. Here _ denotes white space. 59 | """ 60 | years = [" " + str(year) for year in range(start, end)] 61 | tokens = tokenizer(years)["input_ids"] 62 | detokenized = [tokenizer.convert_ids_to_tokens(year_toks) for year_toks in tokens] 63 | valid = torch.tensor([(len(detok) == 2 and len(detok[1]) == 2) for detok in detokenized]) 64 | last_valid_index = None 65 | current_century = None 66 | for i, year in zip(range(len(valid)), range(start, end)): 67 | cent = year // 100 68 | if valid[i]: 69 | if current_century != cent: 70 | current_century = cent 71 | valid[i] = False 72 | if last_valid_index is not None: 73 | valid[last_valid_index] = False 74 | last_valid_index = i 75 | if last_valid_index is not None: 76 | valid[last_valid_index] = False 77 | return torch.arange(start, end)[valid] 78 | 79 | 80 | def to_device(c, device): 81 | return rc.cast_circuit(c, device_dtype=rc.TorchDeviceDtypeOp(device=device)) 82 | 83 | 84 | def add_year_mask_to_circuit(c: rc.Circuit, good_mask: torch.Tensor, device: str = "cpu"): 85 | """Run the circuit on all elements of tokens. Assumes the 'tokens' module exists in the circuit.""" 86 | assert good_mask.ndim == 2 and good_mask.shape[1] == 100 87 | batch_size = good_mask.shape[0] 88 | group = rc.DiscreteVar.uniform_probs_and_group(batch_size) 89 | c = c.update( 90 | "good_mask", 91 | lambda _: rc.cast_circuit( 92 | rc.DiscreteVar(rc.Array(good_mask, name="good_mask"), probs_and_group=group), 93 | device_dtype=rc.TorchDeviceDtypeOp(device=device), 94 | ), 95 | ) 96 | return c, group 97 | 98 | 99 | def load_diff_model( 100 | split_circuit: rc.Circuit, 101 | number_indices: torch.Tensor, 102 | good_logits_masks: torch.Tensor, 103 | device="cpu", 104 | ): 105 | """Take GPT2 split by head and position and create a new circuit that is only computing the logit difference. The labels will be embedded in the circuit as a DiscreteVar. The function return the logit diff circuit and the group used by the DiscreteVar to sample the labels.""" 106 | device_dtype = rc.TorchDeviceDtype(dtype="float32", device=device) 107 | 108 | good_logits_masks = good_logits_masks.float() 109 | 110 | good_mask = make_arr( 111 | torch.zeros( 112 | 100, 113 | ).to(device), 114 | "good_mask", 115 | device_dtype=device_dtype, 116 | ) 117 | split_circuit = rc.cast_circuit(split_circuit, device_dtype=rc.TorchDeviceDtypeOp(device=device)) 118 | bad_mask = rc.Add.from_weighted_nodes((rc.Scalar(1.0), 1), (good_mask, -1)) 119 | indices = [-1, number_indices.to(device)] 120 | 121 | probs = rc.softmax(split_circuit) 122 | year_probs = rc.Index(probs, indices, name="year_probs") # rc.softmax(year_logits, name="year_probs") 123 | good_probs = rc.Einsum.from_einsum_string("l->", year_probs.mul(good_mask), name="good_probs") 124 | bad_probs = rc.Einsum.from_einsum_string("l->", year_probs.mul(bad_mask), name="bad_probs") 125 | prob_diff_circuit = to_device( 126 | rc.Add.from_weighted_nodes((good_probs, 1), (bad_probs, -1), name="prob_diff"), device 127 | ) 128 | diff_circuit, group = add_year_mask_to_circuit(prob_diff_circuit, good_logits_masks, device=device) 129 | return rc.cast_circuit(diff_circuit, device_dtype=rc.TorchDeviceDtypeOp(device=device)), group 130 | 131 | 132 | def replace_inputs( 133 | c: rc.Circuit, 134 | x: torch.Tensor, 135 | input_name: str, 136 | m: rc.IterativeMatcher, 137 | group: rc.Circuit, 138 | array_suffix: str = "_array", 139 | ): 140 | """ 141 | Replace the input on the model branch define by the matcher `m` with a DiscreteVar. 142 | The input in the circuit `c` are expected non batched. 143 | """ 144 | assert x.ndim >= 1 145 | c = c.update( 146 | m.chain(input_name), 147 | lambda _: rc.DiscreteVar( 148 | rc.Array(x, name=input_name + array_suffix), 149 | name=input_name, 150 | probs_and_group=group, 151 | ), 152 | ) 153 | return c 154 | 155 | 156 | def path_patching( 157 | circuit: rc.Circuit, 158 | baseline_data: torch.Tensor, 159 | patch_data: torch.Tensor, 160 | matcher: rc.IterativeMatcher, 161 | group: rc.Circuit, 162 | input_name: str, 163 | ) -> rc.Circuit: 164 | baseline_circuit = replace_inputs( 165 | circuit, 166 | baseline_data, 167 | input_name, 168 | corr_root_matcher, 169 | group, 170 | array_suffix="_baseline", 171 | ) 172 | if len(matcher.get(circuit)) == 0: 173 | return baseline_circuit 174 | patched_circuit = replace_inputs( 175 | baseline_circuit, 176 | patch_data, 177 | input_name, 178 | matcher, 179 | group, 180 | array_suffix="_patched", 181 | ) 182 | return patched_circuit 183 | 184 | 185 | def iterative_path_patching_nocorr( 186 | circuit: rc.Circuit, 187 | matchers_to_extend: List[rc.IterativeMatcher], 188 | baseline_data: torch.Tensor, 189 | patch_data: torch.Tensor, 190 | group: rc.Circuit, 191 | matcher_extenders: List[Callable[[rc.IterativeMatcher], rc.IterativeMatcher]], 192 | input_name: str, 193 | output_shape: Optional[Tuple[int, ...]] = None, 194 | ) -> torch.Tensor: 195 | t1 = time.time() 196 | circuits = [] 197 | sampler = rc.Sampler(rc.RunDiscreteVarAllSpec([group])) 198 | nb_not_found = 0 199 | for matcher_extender in matcher_extenders: 200 | matchers_to_h = [] 201 | for matcher in matchers_to_extend: 202 | matchers_to_h.append(matcher_extender(matcher)) 203 | union_matcher = matchers_to_h[0] 204 | 205 | for matcher in matchers_to_h[1:]: 206 | union_matcher = union_matcher | matcher 207 | 208 | if len(union_matcher.get(circuit)) == 0: 209 | nb_not_found += 1 210 | patched_circuit = path_patching(circuit, baseline_data, patch_data, union_matcher, group, input_name) 211 | patched_circuit = sampler(patched_circuit) # we replace discrete vars by the real arrays 212 | circuits.append(patched_circuit) 213 | 214 | if nb_not_found > 0: 215 | print(f"Warning: No match found for {nb_not_found} matcher extenders") 216 | 217 | # a fancy function to evaluate fast many circuit that share tensors in common 218 | results = rc.optimize_and_evaluate_many( 219 | circuits, 220 | rc.OptimizationSettings(scheduling_simplify=False, scheduling_naive=True), 221 | ) 222 | t2 = time.time() 223 | print(f"Time for path patching :{t2 - t1:.2f} s") 224 | if output_shape is None: 225 | return torch.cat([x.unsqueeze(0) for x in results], dim=0) 226 | 227 | return torch.cat(results).reshape(output_shape) 228 | 229 | def collate(results: torch.Tensor, years: torch.Tensor) -> torch.Tensor: 230 | return torch.stack([results[years == y].mean(0) for y in range(2, 99)]) 231 | 232 | def logit_lens_ln_all(circuit: rc.Circuit, matcher: rc.IterativeMatcherIn, device="cpu"): 233 | component = circuit.get_unique(matcher) 234 | logits = circuit.get_unique("logits") 235 | logits_new_input = logits.update("final.input", lambda _: component) 236 | return rc.Index(logits_new_input, [-1], name=f"{circuit.name}_logit_lens_all") 237 | 238 | def logit_lens_ln(circuit: rc.Circuit, matcher: rc.IterativeMatcherIn, device="cpu"): 239 | component = circuit.get_unique(matcher) 240 | logits = circuit.get_unique("logits") 241 | logits_new_input = logits.update("final.input", lambda _: component) 242 | return rc.Index(logits_new_input, [-1, year_indices.to(device)], name=f"{circuit.name}_logit_lens_years") 243 | 244 | 245 | def make_scrubbed_printer(a, b): 246 | def scrub_colorer(c): 247 | getting_scrubbed = c.are_any_found(a) 248 | getting_unscrubbed = c.are_any_found(b) 249 | if getting_scrubbed and getting_unscrubbed: 250 | return "purple" 251 | elif getting_scrubbed: 252 | return "red" 253 | elif getting_unscrubbed: 254 | return "cyan" 255 | else: 256 | return "lightgrey" 257 | 258 | scrubbed_printer = rc.PrintHtmlOptions(shape_only_when_necessary=False, colorer=scrub_colorer) 259 | return scrubbed_printer 260 | 261 | 262 | def load_and_split_gpt2(max_len: int): 263 | circ_dict, tokenizer, model_info = load_gpt2_small_circuit() 264 | unbound_circuit = circ_dict["t.bind_w"] 265 | 266 | tokens_arr = rc.Array(torch.zeros(max_len).to(torch.long), name="tokens") 267 | # We use this to index into the tok_embeds to get the proper embeddings 268 | token_embeds = rc.GeneralFunction.gen_index(circ_dict["t.w.tok_embeds"], tokens_arr, 0, name="tok_embeds") 269 | bound_circuit = model_info.bind_to_input(unbound_circuit, token_embeds, circ_dict["t.w.pos_embeds"]) 270 | 271 | transformed_circuit = bound_circuit.update( 272 | "t.bind_w", 273 | lambda c: configure_transformer( 274 | c, 275 | To.ATTN_HEAD_MLP_NORM, 276 | split_by_head_config="full", 277 | use_pull_up_head_split=True, 278 | use_flatten_res=True, 279 | ), 280 | ) 281 | transformed_circuit = rc.conform_all_modules(transformed_circuit) 282 | 283 | subbed_circuit = transformed_circuit.cast_module().substitute() 284 | subbed_circuit = subbed_circuit.rename("logits") 285 | 286 | def module_but_norm(circuit: rc.Circuit): 287 | if isinstance(circuit, rc.Module): 288 | if "norm" in circuit.name or "ln" in circuit.name or "final" in circuit.name: 289 | return False 290 | else: 291 | return True 292 | return False 293 | 294 | for i in range(100): 295 | subbed_circuit = subbed_circuit.update(module_but_norm, lambda c: c.cast_module().substitute()) 296 | 297 | renamed_circuit = subbed_circuit.update(rc.Regex(r"[am]\d(.h\d)?$"), lambda c: c.rename(c.name + ".inner")) 298 | renamed_circuit = renamed_circuit.update("t.inp_tok_pos", lambda c: c.rename("embeds")) 299 | 300 | for l in range(model_info.params.num_layers): 301 | # b0 -> a1.input, ... b11 -> final.input 302 | next = "final" if l == model_info.params.num_layers - 1 else f"a{l+1}" 303 | renamed_circuit = renamed_circuit.update(f"b{l}", lambda c: c.rename(f"{next}.input")) 304 | 305 | # b0.m -> m0, etc. 306 | renamed_circuit = renamed_circuit.update(f"b{l}.m", lambda c: c.rename(f"m{l}")) 307 | renamed_circuit = renamed_circuit.update(f"b{l}.m.p_bias", lambda c: c.rename(f"m{l}.p_bias")) 308 | renamed_circuit = renamed_circuit.update(f"b{l}.a", lambda c: c.rename(f"a{l}")) 309 | renamed_circuit = renamed_circuit.update(f"b{l}.a.p_bias", lambda c: c.rename(f"a{l}.p_bias")) 310 | 311 | for h in range(model_info.params.num_layers): 312 | # b0.a.h0 -> a0.h0, etc. 313 | renamed_circuit = renamed_circuit.update(f"b{l}.a.h{h}", lambda c: c.rename(f"a{l}.h{h}")) 314 | 315 | head_and_mlp_matcher = rc.IterativeMatcher(rc.Regex(r"^(a\d\d?.h\d\d?|m\d\d?)$")) 316 | partition = range(max_len) 317 | split_circuit = renamed_circuit.update( 318 | head_and_mlp_matcher, 319 | lambda c: split_to_concat(c, axis=0, partitioning_idxs=partition).rename(c.name + "_by_pos"), 320 | ) 321 | 322 | new_names_dict = {} 323 | for l in range(model_info.params.num_layers): 324 | for i in range(max_len): 325 | for h in range(model_info.params.num_layers): 326 | # b0.a.h0 -> a0.h0, etc. 327 | new_names_dict[f"a{l}.h{h}_at_idx_{i}"] = f"a{l}_h{h}_t{i}" 328 | new_names_dict[f"m{l}_at_idx_{i}"] = f"m{l}_t{i}" 329 | 330 | split_circuit = split_circuit.update( 331 | rc.Matcher(*list(new_names_dict.keys())), lambda c: c.rename(new_names_dict[c.name]) 332 | ) 333 | 334 | return split_circuit 335 | 336 | 337 | def show_mtx(mtx, title="NO TITLE :(", color_map_label="Logit diff variation", **kwargs): 338 | """Show a plotly matrix with a centered color map. Designed to display results of path patching experiments.""" 339 | # we center the color scale on zero by defining the range (-max_abs, +max_abs) 340 | max_val = float(max(abs(mtx.min()), abs(mtx.max()))) 341 | x_labels = [f"h{i}" for i in range(12)] + ["mlp"] 342 | fig = px.imshow( 343 | mtx, 344 | title=title, 345 | labels=dict(x="Head", y="Layer", color=color_map_label), 346 | color_continuous_scale="RdBu", 347 | range_color=(-max_val, max_val), 348 | x=x_labels, 349 | y=[str(i) for i in range(mtx.shape[0])], 350 | aspect="equal", 351 | **kwargs 352 | ) 353 | fig.update_coloraxes(colorbar_title_side="right") 354 | return fig 355 | 356 | 357 | def make_all_nodes_names(max_len: int): 358 | ALL_NODES_NAMES = set( 359 | [ 360 | MLPHeadAndPosSpec(l, cast(HeadOrMlpType, h), pos).to_name("") 361 | for l in range(12) 362 | for h in (list(range(12)) + ["mlp"]) # type: ignore 363 | for pos in range(max_len) 364 | ] 365 | ) 366 | return ALL_NODES_NAMES 367 | 368 | 369 | def make_extender_factory(max_len: int): 370 | ALL_NODES_NAMES = make_all_nodes_names(max_len) 371 | 372 | def extender_factory(node: Union[MLPHeadAndPosSpec,Set[MLPHeadAndPosSpec]], qkv: Optional[str] = None): 373 | """ 374 | `qkv` define the input of the attention block we want to reach. 375 | """ 376 | assert qkv in ["q", "k", "v", None] 377 | 378 | if isinstance(node, set): 379 | node_name = {n.to_name("") for n in node} 380 | nodes_to_ban = ALL_NODES_NAMES.difference(node_name) 381 | else: 382 | node_name = node.to_name("") 383 | nodes_to_ban = ALL_NODES_NAMES.difference(set(node_name)) 384 | 385 | if qkv is None: 386 | attn_block_input = rc.new_traversal(start_depth=0, end_depth=1) 387 | else: 388 | attn_block_input = rc.restrict(f"a.{qkv}", term_if_matches=True, end_depth=8) 389 | 390 | def matcher_extender(m: rc.IterativeMatcher): 391 | return m.chain(attn_block_input).chain( 392 | rc.new_traversal(start_depth=1, end_depth=2).chain( 393 | rc.restrict( 394 | rc.Matcher(node_name), 395 | term_early_at=rc.Matcher(nodes_to_ban), 396 | term_if_matches=True, 397 | ) 398 | ) 399 | ) 400 | 401 | return matcher_extender 402 | 403 | return extender_factory 404 | 405 | 406 | def eval_on_toks(c: rc.Circuit, toks: torch.Tensor): 407 | group = rc.DiscreteVar.uniform_probs_and_group(len(toks)) 408 | c = c.update( 409 | "tokens", 410 | lambda _: rc.DiscreteVar(rc.Array(toks, name="tokens"), probs_and_group=group), 411 | ) 412 | transform = rc.Sampler(rc.RunDiscreteVarAllSpec([group])) 413 | results = transform.sample(c).evaluate() 414 | return results 415 | 416 | 417 | def get_attention_pattern( 418 | c: rc.Circuit, 419 | heads: List[Tuple[int, int]], 420 | toks: torch.Tensor, 421 | add_value_weighted=False, 422 | ): 423 | assert toks.ndim == 2 424 | seq_len = toks.shape[1] 425 | attn_patterns = torch.zeros((len(heads), len(toks), seq_len, seq_len)) 426 | 427 | for i, (l, h) in enumerate(heads): 428 | a = rc.Matcher(f"a{l}.h{h}").chain(rc.restrict("a.attn_probs", term_if_matches=True, end_depth=3)) 429 | pattern_circ = a.get_unique(c) 430 | attn = eval_on_toks(pattern_circ, toks) 431 | 432 | if add_value_weighted: 433 | v = rc.Matcher(f"a{l}.h{h}").chain(rc.restrict("a.v_p_bias", term_if_matches=True, end_depth=3)) 434 | values = v.get_unique(c) 435 | vals = eval_on_toks(values, toks) 436 | vals = torch.linalg.norm(vals, dim=-1) 437 | attn_patterns[i] = torch.einsum("bKQ,bK->bKQ", attn, vals) 438 | else: 439 | attn_patterns[i] = attn 440 | 441 | return attn_patterns 442 | 443 | 444 | def await_without_await(func: Callable[[], Any]): 445 | """We want solution files to be usable when run as a script from the command line (where a top level await would 446 | cause a SyntaxError), so we can do CI on the files. Avoiding top-level awaits also lets us use the normal Python 447 | debugger. 448 | Usage: instead of `await cui.init(port=6789)`, write `await_without_await(lambda: cui.init(port=6789))` 449 | """ 450 | try: 451 | while True: 452 | func().send(None) 453 | except StopIteration: 454 | pass 455 | 456 | 457 | def to_numpy(tensor): 458 | """ 459 | Helper function to convert a tensor to a numpy array. Also works on lists, tuples, and numpy arrays. 460 | """ 461 | if isinstance(tensor, np.ndarray): 462 | return tensor 463 | elif isinstance(tensor, (list, tuple)): 464 | array = np.array(tensor) 465 | return array 466 | elif isinstance(tensor, (torch.Tensor, torch.nn.parameter.Parameter)): 467 | return tensor.detach().cpu().numpy() 468 | elif isinstance(tensor, (int, float, bool, str)): 469 | return np.array(tensor) 470 | else: 471 | raise ValueError(f"Input to to_numpy has invalid type: {type(tensor)}") 472 | 473 | 474 | def imshow(tensor, center_zero=True, zrange=None, color_continuous_scale="RdBu", **kwargs): 475 | if center_zero: 476 | return px.imshow( 477 | to_numpy(tensor), color_continuous_midpoint=0.0, color_continuous_scale=color_continuous_scale, **kwargs 478 | ) 479 | elif zrange is not None: 480 | zmin, zmax = zrange 481 | return px.imshow( 482 | to_numpy(tensor), zmin=zmin, zmax=zmax, color_continuous_scale=color_continuous_scale, **kwargs 483 | ) 484 | else: 485 | return px.imshow(to_numpy(tensor), color_continuous_scale=color_continuous_scale, **kwargs) 486 | 487 | 488 | def show_diffs( 489 | diffs, center_zero=True, zrange=None, title="", xlabel="predicted year", zlabel="logit change", dim=500, **kwargs 490 | ): 491 | return imshow( 492 | diffs, 493 | center_zero=center_zero, 494 | zrange=zrange, 495 | height=dim, 496 | width=dim, 497 | title=title, 498 | labels={"x": xlabel, "y": "YY", "color": zlabel}, 499 | y=[str(i) for i in range(2, 99)], 500 | **kwargs, 501 | ) 502 | 503 | 504 | def split_mlp(mlp_node: rc.Circuit, important_neurons: torch.Tensor) -> rc.Circuit: 505 | activations, weights = mlp_node.children 506 | (pre_activations,) = activations.children 507 | device = "cpu" if mlp_node.device is None else mlp_node.device 508 | 509 | unimportant_neurons = torch.tensor(list(set(range(3072)) - set(important_neurons.tolist()))) 510 | if len(important_neurons) == 0: 511 | return mlp_node.update( 512 | rc.restrict("m.pre", end_depth=3), lambda node: rc.Index(node, [S[:], S[:]], name="m.pre_unimportant") 513 | ) 514 | 515 | if len(unimportant_neurons) == 0: 516 | return mlp_node.update( 517 | rc.restrict("m.pre", end_depth=3), lambda node: rc.Index(node, [S[:], S[:]], name="m.pre_important") 518 | ) 519 | 520 | important_neurons = important_neurons.to(device) 521 | unimportant_neurons = unimportant_neurons.to(device) 522 | 523 | important_pre_activations = to_device( 524 | rc.Index(pre_activations, [S[:], important_neurons], name="m.pre_important"), device 525 | ) 526 | unimportant_pre_activations = to_device( 527 | rc.Index(pre_activations, [S[:], unimportant_neurons], name="m.pre_unimportant"), device 528 | ) 529 | 530 | important_activations = rc.gelu(important_pre_activations, name="m.act_important") 531 | 532 | unimportant_activations = rc.gelu(unimportant_pre_activations, name="m.act_unimportant") 533 | 534 | important_weights = to_device( 535 | rc.Index(weights, [S[:], important_neurons], name=f"{weights.name}_important"), device 536 | ) 537 | unimportant_weights = to_device( 538 | rc.Index(weights, [S[:], unimportant_neurons], name=f"{weights.name}_unimportant"), device 539 | ) 540 | 541 | mlp_important = rc.Einsum.from_einsum_string( 542 | "pi,oi->po", important_activations, important_weights, name=f"{mlp_node.name}_important" 543 | ) 544 | mlp_unimportant = rc.Einsum.from_einsum_string( 545 | "pu,ou->po", unimportant_activations, unimportant_weights, name=f"{mlp_node.name}_unimportant" 546 | ) 547 | 548 | mlp_reconstructed = rc.Add(mlp_important, mlp_unimportant, name=mlp_node.name) 549 | return mlp_reconstructed 550 | 551 | #%% 552 | def mean_logit_diff(logits: torch.Tensor, years: torch.Tensor) -> torch.Tensor: 553 | diffs = [] 554 | for logit, year in zip(logits, years): 555 | diffs.append(logit[year + 1 :].sum() - logit[: year + 1].sum()) 556 | return torch.tensor(diffs) 557 | 558 | def cutoff_sharpness(logits: torch.Tensor, years: torch.Tensor = torch.arange(2, 99)) -> torch.Tensor: 559 | sharpness = logits[torch.arange(len(logits)), years + 1] - logits[torch.arange(len(logits)), years - 1] 560 | return sharpness 561 | 562 | def prob_diff(probs: torch.Tensor, years: torch.Tensor) -> torch.Tensor: 563 | diffs = [] 564 | for prob, year in zip(probs, years): 565 | diffs.append(prob[year + 1 :].sum() - prob[: year + 1].sum()) 566 | return torch.tensor(diffs) -------------------------------------------------------------------------------- /cache/gelu_12_tied.circ: -------------------------------------------------------------------------------- 1 | # info:{"params": {"block_params": {"norm_type": "ln", "attn_bias": true, "attn_pos": false, "use_mlp": true, "mlp_act_type": "gelu", "mlp_output_bias": true}, "num_layers": 12, "use_norm_output": true, "output_bias": false}, "model_class": "GPT", "pos_enc_type": "gpt", "causal_mask": true, "extra": null} 2 | 't.w.tok_embeds' [50257,768] Array 1f668798fb95d0e16b2a0143 3 | 't.w.pos_embeds' [1024,768] Array 71f9b465008914ba28cdba38 4 | 't.bind_w' Module 5 | 't.logits' 6 | 'a0.ln.w.bias_arr' [768] Array b8bd8a90d0db28aa0302de9a ! 'a0.ln.w.bias' [0s] Symbol 8d723104-f773-83c1-3458-a748e9bb17bc 7 | 'a0.ln.w.scale_arr' [768] Array 243969425a9334c478301b1c ! 'a0.ln.w.scale' [0s] Symbol 85776e9a-dd84-f39e-7154-5a137a1d5006 8 | 'a0.w.q_arr' [12,64,768] Array 9f6ad6af4526a12cc44d2838 ! 'a0.w.q' [4s,7s,0s] Symbol 1846d424-c17c-6279-23c6-612f48268673 9 | 'a0.w.k_arr' [12,64,768] Array 5d4410f46dbd1d1452ddfa4b ! 'a0.w.k' [4s,7s,0s] Symbol b4862b21-fb97-d435-8856-1712e8e5216a 10 | 'a0.w.v_arr' [12,64,768] Array be3a40d0b97769f7f18e96ce ! 'a0.w.v' [4s,8s,0s] Symbol 12e0c8b2-bad6-40fb-1948-8dec4f65d4d9 11 | 'a0.w.o_arr' [12,768,64] Array ad7282377b6c9eb52bbbed20 ! 'a0.w.o' [4s,0s,8s] Symbol 5a921187-19c7-8df4-8f4f-f31e78de5857 12 | 'a0.w.q_bias_arr' [12,64] Array 703fa86711f4769810c5a6a6 ! 'a0.w.q_bias' [4s,7s] Symbol fcbd04c3-4021-2ef7-cca5-a5a19e4d6e3c 13 | 'a0.w.k_bias_arr' [12,64] Array 3a44ee2fa4c2d5b11da33c29 ! 'a0.w.k_bias' [4s,7s] Symbol 259f4329-e6f4-590b-9a16-4106cf6a659e 14 | 'a0.w.v_bias_arr' [12,64] Array 2b06e681e6c1171295f014cd ! 'a0.w.v_bias' [4s,8s] Symbol 5487ce1e-af19-922a-d9b8-a714e61a441c 15 | 'a0.w.o_bias_arr' [768] Array c6d8301380189e1b9bdb6b46 ! 'a0.w.o_bias' [0s] Symbol a3f2c9bf-9c63-16b9-50f2-44556f25e2a2 16 | 'm0.ln.w.bias_arr' [768] Array c09a18d23a5c16d077c80161 ! 'm0.ln.w.bias' [0s] Symbol e443df78-9558-867f-5ba9-1faf7a024204 17 | 'm0.ln.w.scale_arr' [768] Array 34fa95b940a136416a6e9219 ! 'm0.ln.w.scale' [0s] Symbol 23a7711a-8133-2876-37eb-dcd9e87a1613 18 | 'm0.w.proj_in_arr' [3072,768] Array e0118cea10d34268dc3679c0 ! 'm0.w.proj_in' [5s,0s] Symbol e3e70682-c209-4cac-629f-6fbed82c07cd 19 | 'm0.w.in_bias_arr' [3072] Array f58f9277ab22159514944405 ! 'm0.w.in_bias' [5s] Symbol f728b4fa-4248-5e3a-0a5d-2f346baa9455 20 | 'm0.w.proj_out_arr' [768,3072] Array 1abd3250105101978e11364d ! 'm0.w.proj_out' [0s,5s] Symbol eb1167b3-67a9-c378-7c65-c1e582e2e662 21 | 'm0.w.out_bias_arr' [768] Array f0f95228ce9514bc1519effb ! 'm0.w.out_bias' [0s] Symbol f7c1bd87-4da5-e709-d471-3d60c8a70639 22 | 'a1.ln.w.bias_arr' [768] Array 4401273331a64f1427708c96 ! 'a1.ln.w.bias' [0s] Symbol 8d88348a-7eed-8d14-f06d-3fef701966a0 23 | 'a1.ln.w.scale_arr' [768] Array 36034415540cb05da0f1e6ef ! 'a1.ln.w.scale' [0s] Symbol ad45f23d-3b1a-11df-587f-d2803bab6c39 24 | 'a1.w.q_arr' [12,64,768] Array b38a3866cc416be09171e17d ! 'a1.w.q' [4s,7s,0s] Symbol b2221a58-008a-05a6-c464-7159c324c985 25 | 'a1.w.k_arr' [12,64,768] Array febf65f5b5df27f04074cb00 ! 'a1.w.k' [4s,7s,0s] Symbol 1a2b8f1f-f1fd-42a2-9755-d4c13a902931 26 | 'a1.w.v_arr' [12,64,768] Array c409c61290454fbaa9dca080 ! 'a1.w.v' [4s,8s,0s] Symbol 025b413f-8a9a-021e-a648-a7dd06839eb9 27 | 'a1.w.o_arr' [12,768,64] Array 6bc6eee091efa8d72e611665 ! 'a1.w.o' [4s,0s,8s] Symbol b9d179e0-6c0f-d4f5-f813-0c4237730edf 28 | 'a1.w.q_bias_arr' [12,64] Array 493549a05e19b327ab254879 ! 'a1.w.q_bias' [4s,7s] Symbol cd447e35-b8b6-d8fe-442e-3d437204e52d 29 | 'a1.w.k_bias_arr' [12,64] Array 21da17b902ac8620662d6210 ! 'a1.w.k_bias' [4s,7s] Symbol 05b6e6e3-07d4-bedc-5143-1193e6c3f339 30 | 'a1.w.v_bias_arr' [12,64] Array 01cc3e825579c9d8a8b2de80 ! 'a1.w.v_bias' [4s,8s] Symbol afbd67f9-6196-99cf-e198-8ad9f06c144a 31 | 'a1.w.o_bias_arr' [768] Array 35d279f23c436901c8fd859f ! 'a1.w.o_bias' [0s] Symbol c381e88f-38c0-c8fd-8712-b8bc076f3787 32 | 'm1.ln.w.bias_arr' [768] Array a4e221a6903280b56efaa436 ! 'm1.ln.w.bias' [0s] Symbol e4b06ce6-0741-c7a8-7ce4-2c8218072e8c 33 | 'm1.ln.w.scale_arr' [768] Array b4c2ad28b0d6a9ede4289f0f ! 'm1.ln.w.scale' [0s] Symbol 9b810e76-6ec9-d286-63ca-828dd5f4b3b2 34 | 'm1.w.proj_in_arr' [3072,768] Array d28bceca4271dd30a4de231c ! 'm1.w.proj_in' [5s,0s] Symbol cd613e30-d8f1-6adf-91b7-584a2265b1f5 35 | 'm1.w.in_bias_arr' [3072] Array cc12e95bf2f327d278ae92f6 ! 'm1.w.in_bias' [5s] Symbol 1e2feb89-414c-343c-1027-c4d1c386bbc4 36 | 'm1.w.proj_out_arr' [768,3072] Array 1fbd5d8219a5d212248d9277 ! 'm1.w.proj_out' [0s,5s] Symbol 78e51061-7311-d8a3-c2ce-6f447ed4d57b 37 | 'm1.w.out_bias_arr' [768] Array 8dded2dcd755340938f78ae9 ! 'm1.w.out_bias' [0s] Symbol 35bf992d-c9e9-c616-612e-7696a6cecc1b 38 | 'a2.ln.w.bias_arr' [768] Array 09a6f8ea29df20698246fe9a ! 'a2.ln.w.bias' [0s] Symbol 2d3d854e-061b-9030-3b08-c6e33c729578 39 | 'a2.ln.w.scale_arr' [768] Array 4916ffc15bdde4855028c40e ! 'a2.ln.w.scale' [0s] Symbol 829a48d4-22fe-99a2-2c70-501e533c9135 40 | 'a2.w.q_arr' [12,64,768] Array 6bef9550b0760ef5592df4c1 ! 'a2.w.q' [4s,7s,0s] Symbol cdbd47d3-64be-8049-a372-db8f6e405d93 41 | 'a2.w.k_arr' [12,64,768] Array 6b8b996417fb60680e6eef86 ! 'a2.w.k' [4s,7s,0s] Symbol ef8acd12-8b4f-2fc1-5f3f-57ebf30b94fa 42 | 'a2.w.v_arr' [12,64,768] Array 5986c2f74846096830c8c42e ! 'a2.w.v' [4s,8s,0s] Symbol 5d300cb9-0706-a045-defc-044a09325626 43 | 'a2.w.o_arr' [12,768,64] Array 4200a1480cd47e42656eea83 ! 'a2.w.o' [4s,0s,8s] Symbol e2520e33-e44c-5055-6c71-c4a66148a86f 44 | 'a2.w.q_bias_arr' [12,64] Array 332d00ed691712b7e99c6661 ! 'a2.w.q_bias' [4s,7s] Symbol 82523e86-feac-7eb7-dc38-f519b91751da 45 | 'a2.w.k_bias_arr' [12,64] Array 044bdc2256bebe6041faaed4 ! 'a2.w.k_bias' [4s,7s] Symbol e6b58de7-44ab-6cce-8087-7b6f71e1f6d2 46 | 'a2.w.v_bias_arr' [12,64] Array 5ca5c8c9cab0bd310d34f8b6 ! 'a2.w.v_bias' [4s,8s] Symbol e8624fab-5186-ee32-ee8d-7ee9770348a0 47 | 'a2.w.o_bias_arr' [768] Array f2c41d8cbe4e4fda1e458bd3 ! 'a2.w.o_bias' [0s] Symbol 2d6c797f-8f7d-9b78-2a1b-e9cd8697bbd0 48 | 'm2.ln.w.bias_arr' [768] Array e50bf57c0cbf9c2d350511cb ! 'm2.ln.w.bias' [0s] Symbol 0925e474-9b57-5bd1-3653-f8dd9b1f282e 49 | 'm2.ln.w.scale_arr' [768] Array 31dbb454b7bd0b59782394dc ! 'm2.ln.w.scale' [0s] Symbol ffed9235-288b-c781-ae66-267594c9c950 50 | 'm2.w.proj_in_arr' [3072,768] Array c5c62b5e5f38da64c00930d2 ! 'm2.w.proj_in' [5s,0s] Symbol d95bafc8-f2a4-d27b-dcf4-bb99f4bea973 51 | 'm2.w.in_bias_arr' [3072] Array 38667b94d2de0e788786a1e0 ! 'm2.w.in_bias' [5s] Symbol 5c6e4337-15ba-2bdd-1772-19d30e7a269f 52 | 'm2.w.proj_out_arr' [768,3072] Array 2015dfdc517da2460ef03a05 ! 'm2.w.proj_out' [0s,5s] Symbol cf1822ff-bc68-8778-2b49-1044d5e34124 53 | 'm2.w.out_bias_arr' [768] Array 449ad5d2e5b83d72cf60ad56 ! 'm2.w.out_bias' [0s] Symbol 4067c358-4ee2-07f8-da94-e3e8ab73738f 54 | 'a3.ln.w.bias_arr' [768] Array 0d902e2f4baf9485bd72631e ! 'a3.ln.w.bias' [0s] Symbol 633a50ee-e0f9-e038-eb8f-624fb804d820 55 | 'a3.ln.w.scale_arr' [768] Array e3cb2a676b919107c714b314 ! 'a3.ln.w.scale' [0s] Symbol 6d4b9adb-ebcd-1f5e-c9c1-8070b6d13089 56 | 'a3.w.q_arr' [12,64,768] Array ab4f429927c39832a75de072 ! 'a3.w.q' [4s,7s,0s] Symbol 65aa9c82-79f2-48b0-8cb4-a0d7d6225675 57 | 'a3.w.k_arr' [12,64,768] Array 545764ed653568f1bada31e9 ! 'a3.w.k' [4s,7s,0s] Symbol ed038db4-de38-3784-26d0-b944a2863a7f 58 | 'a3.w.v_arr' [12,64,768] Array 5a79cba85dde2488a8ed0780 ! 'a3.w.v' [4s,8s,0s] Symbol 28ce6f24-1064-5d51-c6f8-da3eabe19f58 59 | 'a3.w.o_arr' [12,768,64] Array 9fb045e6dfca161427d0da88 ! 'a3.w.o' [4s,0s,8s] Symbol d2d58443-07f0-62ce-c7b3-17d94d1fe09f 60 | 'a3.w.q_bias_arr' [12,64] Array 6324f28a458b76c94d2d76ba ! 'a3.w.q_bias' [4s,7s] Symbol 3b5f3d86-268e-cc45-dc6b-f1e1a399f82a 61 | 'a3.w.k_bias_arr' [12,64] Array 5904b0f9b0373e698461dcbe ! 'a3.w.k_bias' [4s,7s] Symbol 03e0a813-bdc2-ae99-63d2-e49085ef3430 62 | 'a3.w.v_bias_arr' [12,64] Array 8a9cc8efafb3e5e6fbdc18b2 ! 'a3.w.v_bias' [4s,8s] Symbol 0af438d2-9752-4d6a-f51e-8722c21b6092 63 | 'a3.w.o_bias_arr' [768] Array a65fb9122213fdaeca21c8a0 ! 'a3.w.o_bias' [0s] Symbol 98418117-7906-1596-44f9-794cdd933160 64 | 'm3.ln.w.bias_arr' [768] Array ffb966d13a9accd30ff6a50e ! 'm3.ln.w.bias' [0s] Symbol 31162427-3bfd-1d33-8d00-38ec42650644 65 | 'm3.ln.w.scale_arr' [768] Array ec18b1f325e3f177ec16c16c ! 'm3.ln.w.scale' [0s] Symbol 8a7d43b5-7863-3074-b797-0386fee29476 66 | 'm3.w.proj_in_arr' [3072,768] Array c5631b41140dc5491c5030a9 ! 'm3.w.proj_in' [5s,0s] Symbol 21636369-8b52-9b4a-97b7-50923ceb3ffd 67 | 'm3.w.in_bias_arr' [3072] Array 2152a961e570255ea86daf82 ! 'm3.w.in_bias' [5s] Symbol 795b929e-9a9a-80fd-ea7b-5bf55eb561a4 68 | 'm3.w.proj_out_arr' [768,3072] Array bf8e9fbd21d5a57de0fbb490 ! 'm3.w.proj_out' [0s,5s] Symbol 9b08923d-10c6-7fd9-94b2-b8fda02f34a6 69 | 'm3.w.out_bias_arr' [768] Array d8f0a48ba7bae0361369b252 ! 'm3.w.out_bias' [0s] Symbol 781f9c58-d664-5fa9-e8a8-529f035efa25 70 | 'a4.ln.w.bias_arr' [768] Array 7b6ccf9baea480dc5e88d586 ! 'a4.ln.w.bias' [0s] Symbol 3f508249-2d83-a823-3fb6-2d2c81862fc9 71 | 'a4.ln.w.scale_arr' [768] Array 2859517c8c3897e4ee197bbd ! 'a4.ln.w.scale' [0s] Symbol f1cfd992-16df-6486-47ad-ec26793d0e45 72 | 'a4.w.q_arr' [12,64,768] Array 28973829120d1dc499e91220 ! 'a4.w.q' [4s,7s,0s] Symbol 43000de0-1b2e-d40e-d3ad-dccb2c33be0a 73 | 'a4.w.k_arr' [12,64,768] Array e2e8aa8745b2907346301909 ! 'a4.w.k' [4s,7s,0s] Symbol 42a00403-ce80-c4b0-a404-2bb3d4341aad 74 | 'a4.w.v_arr' [12,64,768] Array c1f23fcf295848fc3a20a5e0 ! 'a4.w.v' [4s,8s,0s] Symbol de08caa1-a081-7910-4a25-e4664f5253a0 75 | 'a4.w.o_arr' [12,768,64] Array 55fb56d8b48fbfc517576be8 ! 'a4.w.o' [4s,0s,8s] Symbol d8441b56-1633-2aca-5f55-2773e14b0190 76 | 'a4.w.q_bias_arr' [12,64] Array 2a10d58d68432aded19f3b36 ! 'a4.w.q_bias' [4s,7s] Symbol 06905269-ed6f-0b09-f165-c8ce36e2f24b 77 | 'a4.w.k_bias_arr' [12,64] Array 603691252d8ab8bfa1a6e2e0 ! 'a4.w.k_bias' [4s,7s] Symbol 2a318785-3184-ff27-4591-42deccea2645 78 | 'a4.w.v_bias_arr' [12,64] Array 06ed4e6850d16efdf58f7575 ! 'a4.w.v_bias' [4s,8s] Symbol d93936e1-daca-3c06-f5ff-0c03bb5d7385 79 | 'a4.w.o_bias_arr' [768] Array d182ee6b4f7ed1ca990500f9 ! 'a4.w.o_bias' [0s] Symbol 634f806f-abf4-a07c-5660-02249b191bf4 80 | 'm4.ln.w.bias_arr' [768] Array f0079d4c446cbc007dd3789c ! 'm4.ln.w.bias' [0s] Symbol 8534f457-38d0-48ec-0f10-99c6c3e1b258 81 | 'm4.ln.w.scale_arr' [768] Array 77eac218b9722ae4f26ef278 ! 'm4.ln.w.scale' [0s] Symbol c79d6793-46d4-ac7a-5c39-02b38963dc6e 82 | 'm4.w.proj_in_arr' [3072,768] Array 348547dafe6a52c77189d501 ! 'm4.w.proj_in' [5s,0s] Symbol b8a1abcd-1a69-16c7-4da4-f9fc3c6da5d7 83 | 'm4.w.in_bias_arr' [3072] Array ff2a3e4409b2459d56218016 ! 'm4.w.in_bias' [5s] Symbol 1710cf53-27ac-435a-7a97-c643656412a9 84 | 'm4.w.proj_out_arr' [768,3072] Array 26075ee90a389751db5efde9 ! 'm4.w.proj_out' [0s,5s] Symbol 8ca59966-66ce-ab36-0512-bd1311072231 85 | 'm4.w.out_bias_arr' [768] Array 835944b929354f6f7ea43ff1 ! 'm4.w.out_bias' [0s] Symbol fd724452-ccea-71ff-4a14-876aeaff1a09 86 | 'a5.ln.w.bias_arr' [768] Array ce1bfc257434f729ef57dd97 ! 'a5.ln.w.bias' [0s] Symbol f5cae3bf-3729-c619-c60a-3cab359eeefb 87 | 'a5.ln.w.scale_arr' [768] Array 6c4c9644e0da0388788edf89 ! 'a5.ln.w.scale' [0s] Symbol 2a9eba0c-df56-1d80-2a75-9159fb7ff337 88 | 'a5.w.q_arr' [12,64,768] Array 149e642ff2a536696ddba311 ! 'a5.w.q' [4s,7s,0s] Symbol 617959ce-3f1f-65a8-de52-71007814e8a2 89 | 'a5.w.k_arr' [12,64,768] Array 6001e89ade20f3a2050e00aa ! 'a5.w.k' [4s,7s,0s] Symbol 687c966c-377b-9aa2-bb2e-db20035b7399 90 | 'a5.w.v_arr' [12,64,768] Array c5d28eeab4a0a80be5023967 ! 'a5.w.v' [4s,8s,0s] Symbol c30d8b76-28db-d25e-63b2-29f1c4069545 91 | 'a5.w.o_arr' [12,768,64] Array c3b754d87d710ca1e42b455c ! 'a5.w.o' [4s,0s,8s] Symbol 21da8978-206f-5c66-71e0-c07e9e115e4b 92 | 'a5.w.q_bias_arr' [12,64] Array 0ae54031f880a9eac0c17eb0 ! 'a5.w.q_bias' [4s,7s] Symbol 3fd42359-92ed-cf45-1a1a-fe878b33e968 93 | 'a5.w.k_bias_arr' [12,64] Array 2f0167777ca8ab80fd464b67 ! 'a5.w.k_bias' [4s,7s] Symbol de11cc9d-ea95-9c21-2e9c-82b1478c281d 94 | 'a5.w.v_bias_arr' [12,64] Array 7777d94ba6517daa3b5fa726 ! 'a5.w.v_bias' [4s,8s] Symbol 9e30691c-2386-42ea-126a-1e48cc11d357 95 | 'a5.w.o_bias_arr' [768] Array b7b0216c848a60f4a44f4153 ! 'a5.w.o_bias' [0s] Symbol 015c33b2-df14-61aa-f8eb-18b900745130 96 | 'm5.ln.w.bias_arr' [768] Array 7acc8d333ab8c75be75be5a4 ! 'm5.ln.w.bias' [0s] Symbol 0d464138-a623-3255-3fc1-ea36f17fd374 97 | 'm5.ln.w.scale_arr' [768] Array f8be135a601d087ea26711c2 ! 'm5.ln.w.scale' [0s] Symbol 5f2dd97f-1cfb-10f6-2827-688de6a16a3b 98 | 'm5.w.proj_in_arr' [3072,768] Array be1a55e778cb20257bf2871c ! 'm5.w.proj_in' [5s,0s] Symbol 5bc8fbbc-bde5-c099-4164-d8399f767c45 99 | 'm5.w.in_bias_arr' [3072] Array 2b0df5eb66996e8b36613254 ! 'm5.w.in_bias' [5s] Symbol d76d4330-f144-6bea-b0c1-1fdecb91ce37 100 | 'm5.w.proj_out_arr' [768,3072] Array d64410fe31249e6bb545b4a4 ! 'm5.w.proj_out' [0s,5s] Symbol 87b0b125-ec1d-7da0-a6eb-8c9ebd69fe29 101 | 'm5.w.out_bias_arr' [768] Array 1b6463a649b155cc070f6f74 ! 'm5.w.out_bias' [0s] Symbol c6a53877-7733-0bdb-d721-0dff076ce2ef 102 | 'a6.ln.w.bias_arr' [768] Array e299e86fbcfe93a6f730c296 ! 'a6.ln.w.bias' [0s] Symbol b3642b19-3279-3637-c16c-f5c51801fd9a 103 | 'a6.ln.w.scale_arr' [768] Array b2da920807ab8590086edebc ! 'a6.ln.w.scale' [0s] Symbol 18f918e2-4a8b-0188-cbe1-9514a28a0aaa 104 | 'a6.w.q_arr' [12,64,768] Array 31b41bbae8bda7902425476d ! 'a6.w.q' [4s,7s,0s] Symbol e91b4ad1-69fc-5360-df5c-a32ebad5ccc2 105 | 'a6.w.k_arr' [12,64,768] Array 57bff1be9df75666cdc14fea ! 'a6.w.k' [4s,7s,0s] Symbol b313fc7e-8db9-b92c-903c-2ac9316774fe 106 | 'a6.w.v_arr' [12,64,768] Array 9be3c843255be4658b73f11f ! 'a6.w.v' [4s,8s,0s] Symbol 168e5087-af89-5f5b-9c2c-0ac2cda95957 107 | 'a6.w.o_arr' [12,768,64] Array de48565aabc20b14f30a3a54 ! 'a6.w.o' [4s,0s,8s] Symbol 68f22599-ccdf-540b-5cb5-3ec017d7ab26 108 | 'a6.w.q_bias_arr' [12,64] Array faef4ab9cf92370da98061ae ! 'a6.w.q_bias' [4s,7s] Symbol 181e290a-ae9a-f169-8a0c-510089ce5ef7 109 | 'a6.w.k_bias_arr' [12,64] Array d18d97d594fff8cff21d4250 ! 'a6.w.k_bias' [4s,7s] Symbol a9b3d1a2-43f9-300c-ba98-666ace1c9c17 110 | 'a6.w.v_bias_arr' [12,64] Array 5033e036e1268879e459c0a3 ! 'a6.w.v_bias' [4s,8s] Symbol fd802060-55e8-b3eb-6cb9-185ed822e2f9 111 | 'a6.w.o_bias_arr' [768] Array 6b8673e6be3d803c85c0bb2f ! 'a6.w.o_bias' [0s] Symbol b31a5bf3-71f9-70cf-401f-e4fcce06294d 112 | 'm6.ln.w.bias_arr' [768] Array f40f03605390f6c910fa8681 ! 'm6.ln.w.bias' [0s] Symbol 059a91e1-c527-e279-51c3-42505f877031 113 | 'm6.ln.w.scale_arr' [768] Array aef80a118bc687dd66057325 ! 'm6.ln.w.scale' [0s] Symbol 32b7228f-cd4a-5557-7d24-b39645cf8aa4 114 | 'm6.w.proj_in_arr' [3072,768] Array fe2b91126965219baa39be86 ! 'm6.w.proj_in' [5s,0s] Symbol 14a03569-d26b-9496-92e5-dfe8cb1855fe 115 | 'm6.w.in_bias_arr' [3072] Array 47aab50b3e6af44b206a4f3b ! 'm6.w.in_bias' [5s] Symbol 096d3737-42f9-a039-c320-a4737c2b3abe 116 | 'm6.w.proj_out_arr' [768,3072] Array 5fcb338ac94cf64cbf8102b4 ! 'm6.w.proj_out' [0s,5s] Symbol 9623d7cf-a9ae-7a34-2544-99c7001d9a88 117 | 'm6.w.out_bias_arr' [768] Array 5504f5ec7d1814acfdfc1525 ! 'm6.w.out_bias' [0s] Symbol bc1e3ac1-c27d-b4ec-f72c-2c2678629522 118 | 'a7.ln.w.bias_arr' [768] Array a8bc414d96490a6f79f9d333 ! 'a7.ln.w.bias' [0s] Symbol 923a7369-94e3-bf91-1a61-dbe22e44158b 119 | 'a7.ln.w.scale_arr' [768] Array 4b19ef0ad24c89b721adedcd ! 'a7.ln.w.scale' [0s] Symbol 18f135d2-5f55-7203-3018-50c5a38fd547 120 | 'a7.w.q_arr' [12,64,768] Array f2731beb1a4f1d51efb14fcc ! 'a7.w.q' [4s,7s,0s] Symbol 90c192cf-d3ac-94af-0f21-ddb66cad4a26 121 | 'a7.w.k_arr' [12,64,768] Array 912872639c6f384d4c26469b ! 'a7.w.k' [4s,7s,0s] Symbol 0fd630f1-f29d-0da9-953f-48f1a09f76b5 122 | 'a7.w.v_arr' [12,64,768] Array f7cadc0ee191432d7474ee5b ! 'a7.w.v' [4s,8s,0s] Symbol 8e81973e-0bec-d7b0-3898-d190f9ebdacc 123 | 'a7.w.o_arr' [12,768,64] Array 621079f1fecb8c37091ec7db ! 'a7.w.o' [4s,0s,8s] Symbol 92276658-1e27-a1c0-8a6a-63ec24ede6a4 124 | 'a7.w.q_bias_arr' [12,64] Array 8705718924f1e3a56a9825b8 ! 'a7.w.q_bias' [4s,7s] Symbol a170b338-3926-3059-f28c-105d1fb17c23 125 | 'a7.w.k_bias_arr' [12,64] Array b03aa64dde9523701323bb0a ! 'a7.w.k_bias' [4s,7s] Symbol 0cb1e29c-658c-da14-95e6-0af593bd04cf 126 | 'a7.w.v_bias_arr' [12,64] Array dbda3f4bd47a592d4b3b3a5a ! 'a7.w.v_bias' [4s,8s] Symbol 6b4cb242-4a23-d596-2217-beaddbc496cb 127 | 'a7.w.o_bias_arr' [768] Array 1205d32607eb94acc9f58f79 ! 'a7.w.o_bias' [0s] Symbol ae97ba94-d0ed-a82f-8f6d-05584ef8aa38 128 | 'm7.ln.w.bias_arr' [768] Array 55067071af4948b68b65b794 ! 'm7.ln.w.bias' [0s] Symbol 6b0d549b-6f03-675a-1600-a35a099950d8 129 | 'm7.ln.w.scale_arr' [768] Array 91f4297e51165c716211dfe2 ! 'm7.ln.w.scale' [0s] Symbol 8d116ece-1738-f7d9-3d9c-172411e20b8f 130 | 'm7.w.proj_in_arr' [3072,768] Array 37cf189e7b7b24371041a474 ! 'm7.w.proj_in' [5s,0s] Symbol 6513270e-269e-0d37-f2a7-4de452e6b438 131 | 'm7.w.in_bias_arr' [3072] Array 785b0f7b3bceca6091780cac ! 'm7.w.in_bias' [5s] Symbol d23f0824-128b-2f33-0c5c-7fd0a6a3a450 132 | 'm7.w.proj_out_arr' [768,3072] Array dca283786ec766a460bf5712 ! 'm7.w.proj_out' [0s,5s] Symbol 9531985d-5d9d-c9f8-1818-e811892f902b 133 | 'm7.w.out_bias_arr' [768] Array f5f20d6effdb5941c6dc9214 ! 'm7.w.out_bias' [0s] Symbol 36f675cc-81e7-4ef5-e8e2-5d940ed90475 134 | 'a8.ln.w.bias_arr' [768] Array b6de5bf0e215bdf5efffc2c6 ! 'a8.ln.w.bias' [0s] Symbol 1607b1c4-b0f9-1306-3c02-e56756a3e957 135 | 'a8.ln.w.scale_arr' [768] Array aa478f6a0a43f451bf83f4a3 ! 'a8.ln.w.scale' [0s] Symbol 844dbc0c-a654-23a9-e744-b24e7f61701e 136 | 'a8.w.q_arr' [12,64,768] Array 4d61af1f9a5d7fc6e9f01f01 ! 'a8.w.q' [4s,7s,0s] Symbol 67164890-d49d-0ac1-e5b8-063831360a40 137 | 'a8.w.k_arr' [12,64,768] Array f58a6eaaf74077af93d5516f ! 'a8.w.k' [4s,7s,0s] Symbol 852a5fba-444a-df42-b37f-5722051e2670 138 | 'a8.w.v_arr' [12,64,768] Array 95a0f431dbb3911151e4cde7 ! 'a8.w.v' [4s,8s,0s] Symbol a9b7e3ea-1d1d-784f-b9db-434b610b1631 139 | 'a8.w.o_arr' [12,768,64] Array 69a4f0c561b50c159efe21f5 ! 'a8.w.o' [4s,0s,8s] Symbol d45c39a3-9ec3-53c1-62e9-17d310269470 140 | 'a8.w.q_bias_arr' [12,64] Array dc1e6a8bd49b8cfa86716676 ! 'a8.w.q_bias' [4s,7s] Symbol c24f6aa8-3bf3-6a14-7c2f-7ad016edc5d4 141 | 'a8.w.k_bias_arr' [12,64] Array 2446c2072045ee89c79c3ac9 ! 'a8.w.k_bias' [4s,7s] Symbol e941aa79-e6ed-af80-796d-3bc4685ca8af 142 | 'a8.w.v_bias_arr' [12,64] Array b13a3c4410d068d3d1b36ca7 ! 'a8.w.v_bias' [4s,8s] Symbol d0718c1a-fdd9-a78d-18df-f3934223aa56 143 | 'a8.w.o_bias_arr' [768] Array 4721668a7c2023ebcb814020 ! 'a8.w.o_bias' [0s] Symbol 0edca4ec-a92d-04a3-1b94-1f4360908405 144 | 'm8.ln.w.bias_arr' [768] Array 79cfcee256b66fad26ecd2ae ! 'm8.ln.w.bias' [0s] Symbol 7cc661e9-7589-ca4a-07c1-5471a4517d6c 145 | 'm8.ln.w.scale_arr' [768] Array 43b9f1936fe05a6e23d3465e ! 'm8.ln.w.scale' [0s] Symbol 92b850ad-7eb7-2f82-63f6-5da874007cb4 146 | 'm8.w.proj_in_arr' [3072,768] Array d605100a49084a8a18f57a97 ! 'm8.w.proj_in' [5s,0s] Symbol 6018366c-f658-f7a7-5ed3-4fe53a096533 147 | 'm8.w.in_bias_arr' [3072] Array 59c5e42508639af57fb58f8a ! 'm8.w.in_bias' [5s] Symbol 0b3510b0-b46e-e1da-3170-17a6205738d1 148 | 'm8.w.proj_out_arr' [768,3072] Array 598e4a2dbf40234543b028bf ! 'm8.w.proj_out' [0s,5s] Symbol cfaf0010-3f58-4ad4-2308-24d215ceb3a1 149 | 'm8.w.out_bias_arr' [768] Array 8cef5a8a889c3c7bccdedfbf ! 'm8.w.out_bias' [0s] Symbol 6694f229-359b-1548-81a0-d5b3ffc6e35c 150 | 'a9.ln.w.bias_arr' [768] Array 26c4b5d0911ebe27531ad2aa ! 'a9.ln.w.bias' [0s] Symbol c5c7d186-1674-518d-e3bb-41b36bf82959 151 | 'a9.ln.w.scale_arr' [768] Array 2d757f5523248058d3729e2e ! 'a9.ln.w.scale' [0s] Symbol 6583d614-35bb-5c11-e950-27004448a6a1 152 | 'a9.w.q_arr' [12,64,768] Array 749e3921ccf4a9eb24cf5025 ! 'a9.w.q' [4s,7s,0s] Symbol f3868254-73b7-a490-f23b-2cc4b4174a67 153 | 'a9.w.k_arr' [12,64,768] Array 2eecd09c53f66cbfee410f15 ! 'a9.w.k' [4s,7s,0s] Symbol 21e6a46f-1c67-0ea9-0d24-3a163cee5e2c 154 | 'a9.w.v_arr' [12,64,768] Array 34833bf962d86848248ea9bc ! 'a9.w.v' [4s,8s,0s] Symbol b03da701-c632-976a-1036-3c5f972651da 155 | 'a9.w.o_arr' [12,768,64] Array 5cc649a06844f7b26b01c8b8 ! 'a9.w.o' [4s,0s,8s] Symbol 3478442b-4a8a-a593-eb40-a9b81a070205 156 | 'a9.w.q_bias_arr' [12,64] Array 6c59a5a299a7f49872daa908 ! 'a9.w.q_bias' [4s,7s] Symbol 2b1e1885-283b-73a6-6c2e-a417b99de255 157 | 'a9.w.k_bias_arr' [12,64] Array c0d48503c6c4a7ae36187369 ! 'a9.w.k_bias' [4s,7s] Symbol fdb119a9-ec80-1bdf-df29-65b3819ad93b 158 | 'a9.w.v_bias_arr' [12,64] Array a54960ff07ecbf64388bea51 ! 'a9.w.v_bias' [4s,8s] Symbol e323bb2a-bf00-188d-ca22-e4c76237dbe6 159 | 'a9.w.o_bias_arr' [768] Array b14a34e2b297c78ec57ef0c9 ! 'a9.w.o_bias' [0s] Symbol cb01c357-b9c7-e435-396b-cb8fac9abb0c 160 | 'm9.ln.w.bias_arr' [768] Array 3b33846598253ad0f7379d90 ! 'm9.ln.w.bias' [0s] Symbol b339a476-9ddc-c6f8-efb6-fbfe8de4ab47 161 | 'm9.ln.w.scale_arr' [768] Array 74085c06622d359f30674894 ! 'm9.ln.w.scale' [0s] Symbol 2b5ebaa0-6107-6dc3-ba6a-ce6c0a78250f 162 | 'm9.w.proj_in_arr' [3072,768] Array 1e02ff436776676edcd4c6a8 ! 'm9.w.proj_in' [5s,0s] Symbol 4462ebfc-5f91-5ef0-9cfb-ac6e7687a66e 163 | 'm9.w.in_bias_arr' [3072] Array 368584834a25da3851f88775 ! 'm9.w.in_bias' [5s] Symbol ad38835e-ddd6-ff55-2fa7-3207237751aa 164 | 'm9.w.proj_out_arr' [768,3072] Array 6d81000d8018038ce5e4d783 ! 'm9.w.proj_out' [0s,5s] Symbol 76b67451-80b6-5386-569c-803601a5ba50 165 | 'm9.w.out_bias_arr' [768] Array 3b6757c0ac358a359a0a5d38 ! 'm9.w.out_bias' [0s] Symbol 558298e2-14b0-44d7-9acd-8acde5f6db1d 166 | 'a10.ln.w.bias_arr' [768] Array 204fe3b2ac0b8aa6c9f3c4cf ! 'a10.ln.w.bias' [0s] Symbol 31e875ba-224c-0601-3c53-d0e30109c207 167 | 'a10.ln.w.scale_arr' [768] Array e9582b344eb52610b42bb6e4 ! 'a10.ln.w.scale' [0s] Symbol 894deab4-4d88-450f-e8da-c663f0e58650 168 | 'a10.w.q_arr' [12,64,768] Array 26aa0a91dd167a3053f12d12 ! 'a10.w.q' [4s,7s,0s] Symbol 6b9f15c4-0b68-0c1c-5c74-e45eff1e5bef 169 | 'a10.w.k_arr' [12,64,768] Array 45618735b57ee6af6420abb8 ! 'a10.w.k' [4s,7s,0s] Symbol d3ac535f-489b-340f-6bd7-f50361b0ee09 170 | 'a10.w.v_arr' [12,64,768] Array a54ca681d6189ce0b4b64c3f ! 'a10.w.v' [4s,8s,0s] Symbol 5cd2875e-a96e-c2b3-4d98-4bffaf949e5e 171 | 'a10.w.o_arr' [12,768,64] Array 0f6fdeaf0051cd3aa07e5fd9 ! 'a10.w.o' [4s,0s,8s] Symbol 708cc1b6-f829-d29f-3d48-06c2fb7f6f5d 172 | 'a10.w.q_bias_arr' [12,64] Array 33898cfe3d6fc8caed150352 ! 'a10.w.q_bias' [4s,7s] Symbol 5ae6a228-9a6a-b329-2381-23e5dc338383 173 | 'a10.w.k_bias_arr' [12,64] Array 3ea0abb517cb257431b3e19c ! 'a10.w.k_bias' [4s,7s] Symbol 2cb7362c-74f2-e2ed-4327-79eeacca7f0d 174 | 'a10.w.v_bias_arr' [12,64] Array ed49ed1463e07e0ef76d921b ! 'a10.w.v_bias' [4s,8s] Symbol dc2c2e2c-c491-04d0-74f9-42cb220adb0a 175 | 'a10.w.o_bias_arr' [768] Array 587315312c11c6c605064b66 ! 'a10.w.o_bias' [0s] Symbol 953b00b0-0b54-aa22-600f-ecc19d02fc90 176 | 'm10.ln.w.bias_arr' [768] Array 4f224167100c81e818abe9a6 ! 'm10.ln.w.bias' [0s] Symbol 137a9777-53e8-eb43-7d76-3fb9854a9657 177 | 'm10.ln.w.scale_arr' [768] Array cd9d0cc126c56bd63f063934 ! 'm10.ln.w.scale' [0s] Symbol bedc25e6-f3eb-cf12-f3d0-6f863fffc830 178 | 'm10.w.proj_in_arr' [3072,768] Array 3af123dcc5d43cded8ba06a6 ! 'm10.w.proj_in' [5s,0s] Symbol 7b89296c-6dcb-ac50-0857-7eb1924770d3 179 | 'm10.w.in_bias_arr' [3072] Array ef44bb7ea62544a1454e4239 ! 'm10.w.in_bias' [5s] Symbol 766bad07-34c2-da80-03cc-0f2793fdcab8 180 | 'm10.w.proj_out_arr' [768,3072] Array 471ca6d3aab1d42bfc11fdc8 ! 'm10.w.proj_out' [0s,5s] Symbol 470b9805-d2d6-b877-7dc5-9a3ad035d259 181 | 'm10.w.out_bias_arr' [768] Array 5a0ba6291353be100550a8d7 ! 'm10.w.out_bias' [0s] Symbol 08ceac39-2904-cdef-cf84-b683a749f9c5 182 | 'a11.ln.w.bias_arr' [768] Array 0354348a5e5e76080c298fee ! 'a11.ln.w.bias' [0s] Symbol 320094ea-d7a9-4ded-9749-1e2370c6a5b8 183 | 'a11.ln.w.scale_arr' [768] Array 311eaac357f33fd1f57ae594 ! 'a11.ln.w.scale' [0s] Symbol 4b4d8474-a3ea-284d-3bd0-334684e55160 184 | 'a11.w.q_arr' [12,64,768] Array 27b8a5a29fdf846be6358ae5 ! 'a11.w.q' [4s,7s,0s] Symbol e3eff9c0-cf44-dd3f-89e7-d15f17362f25 185 | 'a11.w.k_arr' [12,64,768] Array d464ae8b3a20f746fbebb47e ! 'a11.w.k' [4s,7s,0s] Symbol 73f778aa-f6fa-5db8-656a-bd72fb710734 186 | 'a11.w.v_arr' [12,64,768] Array d569195b7e32ff26c65486b8 ! 'a11.w.v' [4s,8s,0s] Symbol d4ea65d0-03d7-1684-9f85-58a628518867 187 | 'a11.w.o_arr' [12,768,64] Array f9d9b15d185e03baf2252c5e ! 'a11.w.o' [4s,0s,8s] Symbol 99809225-3def-fa38-e12b-2b8f30b17d0b 188 | 'a11.w.q_bias_arr' [12,64] Array 69fe0e7784ec280b89e1f99c ! 'a11.w.q_bias' [4s,7s] Symbol 986e86cb-0ab8-ab67-a26b-7f62b1852f27 189 | 'a11.w.k_bias_arr' [12,64] Array 3357eb518d0775e73fdf2ec7 ! 'a11.w.k_bias' [4s,7s] Symbol a66b0d38-9d95-847e-bd29-9753a7677796 190 | 'a11.w.v_bias_arr' [12,64] Array c2a1a38e831a6c1120e00ca1 ! 'a11.w.v_bias' [4s,8s] Symbol 09208a65-0f3e-bdd3-102b-938b8743feb6 191 | 'a11.w.o_bias_arr' [768] Array f77dfbaf3d05a1857db87e9f ! 'a11.w.o_bias' [0s] Symbol 5387f613-76c4-68ae-c732-1cc007b37e14 192 | 'm11.ln.w.bias_arr' [768] Array 36eeea068893621f7c3f0660 ! 'm11.ln.w.bias' [0s] Symbol 2fa91425-cb00-8853-9d2c-67eda13ffe79 193 | 'm11.ln.w.scale_arr' [768] Array 6fc0cab3f8360125b859ec21 ! 'm11.ln.w.scale' [0s] Symbol 244caf9c-4dab-b481-7253-edc618187993 194 | 'm11.w.proj_in_arr' [3072,768] Array 3abb8fbb5fee4c667aa80af8 ! 'm11.w.proj_in' [5s,0s] Symbol db5b5fab-8f4d-3e27-dda1-494c73cf256d 195 | 'm11.w.in_bias_arr' [3072] Array 2c164f2e48f91e3ced0c7ab0 ! 'm11.w.in_bias' [5s] Symbol 73ab4876-7734-d7c1-c7fd-e805ec99108d 196 | 'm11.w.proj_out_arr' [768,3072] Array a729aa0277a2f6ecc729843d ! 'm11.w.proj_out' [0s,5s] Symbol 309d6b79-965e-da32-dae4-45508201e2bd 197 | 'm11.w.out_bias_arr' [768] Array 29c898d490237f73d915e15e ! 'm11.w.out_bias' [0s] Symbol 79cb9e86-830c-71c2-cdcc-69292f45e678 198 | 'final.ln.w.bias_arr' [768] Array a6f0ac1de32ab667e10fcd1d ! 'final.ln.w.bias' 199 | 'final.ln.w.scale_arr' [768] Array ea8915931fdea0d8e8f39b20 ! 'final.ln.w.scale' 200 | 't.w.unembed_arr' [50257,768] Array 1f668798fb95d0e16b2a0143 ! 't.w.unembed' 201 | 202 | # originally gelu_twelve_layers (aka gpt2-small) 203 | --------------------------------------------------------------------------------