├── .gitignore
├── train.py
├── analysis.py
├── README.md
├── trainer.py
├── buffer.py
├── utils.py
└── crosscoder.py
/.gitignore:
--------------------------------------------------------------------------------
1 | wandb/*
2 | __pycache__/*
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | # %%
2 | from utils import *
3 | from trainer import Trainer
4 | # %%
5 | device = 'cuda:0'
6 |
7 | base_model = HookedTransformer.from_pretrained(
8 | "gemma-2-2b",
9 | device=device,
10 | )
11 |
12 | chat_model = HookedTransformer.from_pretrained(
13 | "gemma-2-2b-it",
14 | device=device,
15 | )
16 |
17 | # %%
18 | all_tokens = load_pile_lmsys_mixed_tokens()
19 |
20 | # %%
21 | default_cfg = {
22 | "seed": 49,
23 | "batch_size": 4096,
24 | "buffer_mult": 128,
25 | "lr": 5e-5,
26 | "num_tokens": 400_000_000,
27 | "l1_coeff": 2,
28 | "beta1": 0.9,
29 | "beta2": 0.999,
30 | "d_in": base_model.cfg.d_model,
31 | "dict_size": 2**14,
32 | "seq_len": 1024,
33 | "enc_dtype": "fp32",
34 | "model_name": "gemma-2-2b",
35 | "site": "resid_pre",
36 | "device": "cuda:0",
37 | "model_batch_size": 4,
38 | "log_every": 100,
39 | "save_every": 30000,
40 | "dec_init_norm": 0.08,
41 | "hook_point": "blocks.14.hook_resid_pre",
42 | "wandb_project": "YOUR_WANDB_PROJECT",
43 | "wandb_entity": "YOUR_WANDB_ENTITY",
44 | }
45 | cfg = arg_parse_update_cfg(default_cfg)
46 |
47 | trainer = Trainer(cfg, base_model, chat_model, all_tokens)
48 | trainer.train()
49 | # %%
--------------------------------------------------------------------------------
/analysis.py:
--------------------------------------------------------------------------------
1 | # %%
2 | from utils import *
3 | from crosscoder import CrossCoder
4 | torch.set_grad_enabled(False);
5 | # %%
6 | cross_coder = CrossCoder.load_from_hf()
7 |
8 | # %%
9 | norms = cross_coder.W_dec.norm(dim=-1)
10 | norms.shape
11 | # %%
12 | relative_norms = norms[:, 1] / norms.sum(dim=-1)
13 | relative_norms.shape
14 | # %%
15 |
16 | fig = px.histogram(
17 | relative_norms.detach().cpu().numpy(),
18 | title="Gemma 2 2B Base vs IT Model Diff",
19 | labels={"value": "Relative decoder norm strength"},
20 | nbins=200,
21 | )
22 |
23 | fig.update_layout(showlegend=False)
24 | fig.update_yaxes(title_text="Number of Latents")
25 |
26 | # Update x-axis ticks
27 | fig.update_xaxes(
28 | tickvals=[0, 0.25, 0.5, 0.75, 1.0],
29 | ticktext=['0', '0.25', '0.5', '0.75', '1.0']
30 | )
31 |
32 | fig.show()
33 |
34 | # %%
35 | shared_latent_mask = (relative_norms < 0.7) & (relative_norms > 0.3)
36 | shared_latent_mask.shape
37 | # %%
38 | # Cosine similarity of decoder vectors between models
39 |
40 | cosine_sims = (cross_coder.W_dec[:, 0, :] * cross_coder.W_dec[:, 1, :]).sum(dim=-1) / (cross_coder.W_dec[:, 0, :].norm(dim=-1) * cross_coder.W_dec[:, 1, :].norm(dim=-1))
41 | cosine_sims.shape
42 | # %%
43 | import plotly.express as px
44 | import torch
45 |
46 | fig = px.histogram(
47 | cosine_sims[shared_latent_mask].to(torch.float32).detach().cpu().numpy(),
48 | #title="Cosine similarity of decoder vectors between models",
49 | log_y=True, # Sets the y-axis to log scale
50 | range_x=[-1, 1], # Sets the x-axis range from -1 to 1
51 | nbins=100, # Adjust this value to change the number of bins
52 | labels={"value": "Cosine similarity of decoder vectors between models"}
53 | )
54 |
55 | fig.update_layout(showlegend=False)
56 | fig.update_yaxes(title_text="Number of Latents (log scale)")
57 |
58 | fig.show()
59 | # %%
60 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # TLDR
2 |
3 | Open source replication of [Anthropic's Crosscoders for Model Diffing](https://transformer-circuits.pub/2024/crosscoders/index.html#model-diffing).
4 | The crosscoder was trained to diff the residual streams of the Gemma-2 2B base and IT models at a middle layer (`blocks.14.hook_resid_pre`).
5 |
6 | See this [blog post](https://www.lesswrong.com/posts/srt6JXsRMtmqAJavD/open-source-replication-of-anthropic-s-crosscoder-paper-for) for more details.
7 |
8 | # Reading This Codebase
9 |
10 | This implementation is adapted from Neel Nanda's code at https://github.com/neelnanda-io/Crosscoders
11 |
12 | * `train.py` is the main training script for the crosscoder. Run it with `python train.py`.
13 | * `trainer.py` contains the pytorch boilerplate code that actually trains the crosscoder.
14 | * `crosscoder.py` contains the pytorch implementation of the crosscoder. It was implemented to diff two different models, but should be easily hackable to work with an arbitrary number of models.
15 | * `buffer.py` contains code to extract activations from both models, concatenate them, and store them in a buffer which is shuffled and periodically refreshed during training.
16 | * `analysis.py` is a short notebook replicating some of the results from the Anthropic paper. See this [colab notebook](https://colab.research.google.com/drive/124ODki4dUjfi21nuZPHRySALx9I74YHj?usp=sharing) for a more comprehensive demo.
17 |
18 | It likely won't work out of the box (some paths and wandb settings are hardcoded), but hopefully it's pretty hackable. Some tips:
19 | * In `train.py` I just set the cfg by editing the code, rather than using command line arguments. You'll need to change the "wandb_project" and "wandb_entity" in the cfg dict.
20 | * You'll need to create a checkpoints dir in `/workspace/crosscoder-model-diff-replication/checkpoints` (or change this path in the code). I would sanity check this with a short test run to make sure your weights will be properly saved at the end of training.
21 | * We load training data from https://huggingface.co/datasets/ckkissane/pile-lmsys-mix-1m-tokenized-gemma-2 as a global tensor called all_tokens, and pass this to the Trainer in `train.py`. This is very hacky, but should be easy to swap out if needed.
22 | * In `buffer.py` we separately normalize both the base and chat activations such that they both have average norm sqrt(d_model). This should be handled for you during training, but note that you'll also need to normalize activations during analysis (or fold the normalization scaling factors into the crosscoder weights). A sketch of doing this at analysis time is included below.
23 |
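24 | As a rough reference for that last point, here's a minimal sketch (untested, treat it as pseudocode) of using the trained crosscoder on fresh activations at analysis time: it re-estimates the per-model scaling factors the same way `Buffer.estimate_norm_scaling_factor` does, stacks base/chat activations in the `[batch, 2, d_model]` layout from `buffer.py`, and normalizes them before calling `encode`. Only `CrossCoder.load_from_hf`, the hook point, and that layout come from this codebase; the rest is illustrative glue.
25 | 
26 | ```python
27 | import torch
28 | from utils import *  # HookedTransformer, load_pile_lmsys_mixed_tokens, etc.
29 | from crosscoder import CrossCoder
30 | 
31 | cross_coder = CrossCoder.load_from_hf()
32 | hook_point = cross_coder.cfg["hook_point"]  # "blocks.14.hook_resid_pre"
33 | device = cross_coder.cfg["device"]
34 | 
35 | base_model = HookedTransformer.from_pretrained("gemma-2-2b", device=device)
36 | chat_model = HookedTransformer.from_pretrained("gemma-2-2b-it", device=device)
37 | tokens = load_pile_lmsys_mixed_tokens()
38 | 
39 | @torch.no_grad()
40 | def estimate_scaling_factor(model, batch_size=4, n_batches=100):
41 |     # sqrt(d_model) / mean residual-stream norm, mirroring Buffer.estimate_norm_scaling_factor
42 |     norms = []
43 |     for i in range(n_batches):
44 |         _, cache = model.run_with_cache(
45 |             tokens[i * batch_size : (i + 1) * batch_size],
46 |             names_filter=hook_point,
47 |             return_type=None,
48 |         )
49 |         norms.append(cache[hook_point].norm(dim=-1).mean().item())
50 |     return (model.cfg.d_model ** 0.5) / (sum(norms) / len(norms))
51 | 
52 | factors = torch.tensor(
53 |     [estimate_scaling_factor(base_model), estimate_scaling_factor(chat_model)],
54 |     device=device,
55 | )
56 | 
57 | # A small batch of paired activations in the layout buffer.py uses: base in slot 0, chat in slot 1
58 | with torch.no_grad():
59 |     _, cache_A = base_model.run_with_cache(tokens[:4], names_filter=hook_point)
60 |     _, cache_B = chat_model.run_with_cache(tokens[:4], names_filter=hook_point)
61 | acts = torch.stack([cache_A[hook_point], cache_B[hook_point]], dim=2)  # [batch, seq, 2, d_model]
62 | acts = acts[:, 1:].reshape(-1, 2, cross_coder.cfg["d_in"]).float()     # drop BOS, flatten
63 | 
64 | # Apply the same per-model normalization the buffer uses during training, then encode
65 | latent_acts = cross_coder.encode(acts * factors[None, :, None])  # [n_tokens, dict_size]
66 | ```
67 | 
68 | Equivalently, you can fold the factors into the crosscoder weights once (e.g. `cross_coder.W_enc.data *= factors[:, None, None]`) and feed it raw activations; to also get reconstructions back in raw activation space, divide `W_dec.data` and `b_dec.data` by the corresponding factor.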
--------------------------------------------------------------------------------
/trainer.py:
--------------------------------------------------------------------------------
1 | from utils import *
2 | from crosscoder import CrossCoder
3 | from buffer import Buffer
4 | import tqdm
5 |
6 | from torch.nn.utils import clip_grad_norm_
7 | class Trainer:
8 | def __init__(self, cfg, model_A, model_B, all_tokens):
9 | self.cfg = cfg
10 | self.model_A = model_A
11 | self.model_B = model_B
12 | self.crosscoder = CrossCoder(cfg)
13 | self.buffer = Buffer(cfg, model_A, model_B, all_tokens)
14 | self.total_steps = cfg["num_tokens"] // cfg["batch_size"]
15 |
16 | self.optimizer = torch.optim.Adam(
17 | self.crosscoder.parameters(),
18 | lr=cfg["lr"],
19 | betas=(cfg["beta1"], cfg["beta2"]),
20 | )
21 | self.scheduler = torch.optim.lr_scheduler.LambdaLR(
22 | self.optimizer, self.lr_lambda
23 | )
24 | self.step_counter = 0
25 |
26 | wandb.init(project=cfg["wandb_project"], entity=cfg["wandb_entity"])
27 |
28 | def lr_lambda(self, step):
29 | if step < 0.8 * self.total_steps:
30 | return 1.0
31 | else:
32 | return 1.0 - (step - 0.8 * self.total_steps) / (0.2 * self.total_steps)
33 |
34 | def get_l1_coeff(self):
35 | # Linearly increases from 0 to cfg["l1_coeff"] over the first 0.05 * self.total_steps steps, then keeps it constant
36 | if self.step_counter < 0.05 * self.total_steps:
37 | return self.cfg["l1_coeff"] * self.step_counter / (0.05 * self.total_steps)
38 | else:
39 | return self.cfg["l1_coeff"]
40 |
41 | def step(self):
42 | acts = self.buffer.next()
43 | losses = self.crosscoder.get_losses(acts)
44 | loss = losses.l2_loss + self.get_l1_coeff() * losses.l1_loss
45 | loss.backward()
46 | clip_grad_norm_(self.crosscoder.parameters(), max_norm=1.0)
47 | self.optimizer.step()
48 | self.scheduler.step()
49 | self.optimizer.zero_grad()
50 |
51 | loss_dict = {
52 | "loss": loss.item(),
53 | "l2_loss": losses.l2_loss.item(),
54 | "l1_loss": losses.l1_loss.item(),
55 | "l0_loss": losses.l0_loss.item(),
56 | "l1_coeff": self.get_l1_coeff(),
57 | "lr": self.scheduler.get_last_lr()[0],
58 | "explained_variance": losses.explained_variance.mean().item(),
59 | "explained_variance_A": losses.explained_variance_A.mean().item(),
60 | "explained_variance_B": losses.explained_variance_B.mean().item(),
61 | }
62 | self.step_counter += 1
63 | return loss_dict
64 |
65 | def log(self, loss_dict):
66 | wandb.log(loss_dict, step=self.step_counter)
67 | print(loss_dict)
68 |
69 | def save(self):
70 | self.crosscoder.save()
71 |
72 | def train(self):
73 | self.step_counter = 0
74 | try:
75 | for i in tqdm.trange(self.total_steps):
76 | loss_dict = self.step()
77 | if i % self.cfg["log_every"] == 0:
78 | self.log(loss_dict)
79 | if (i + 1) % self.cfg["save_every"] == 0:
80 | self.save()
81 | finally:
82 | self.save()
--------------------------------------------------------------------------------
/buffer.py:
--------------------------------------------------------------------------------
1 | from utils import *
2 | from transformer_lens import ActivationCache
3 | import tqdm
4 |
5 | class Buffer:
6 | """
7 |     This defines a data buffer, storing a stack of activations across both models, which is used to train the crosscoder. It automatically runs the models to generate more activations when it gets halfway empty.
8 | """
9 |
10 | def __init__(self, cfg, model_A, model_B, all_tokens):
11 | assert model_A.cfg.d_model == model_B.cfg.d_model
12 | self.cfg = cfg
13 | self.buffer_size = cfg["batch_size"] * cfg["buffer_mult"]
14 | self.buffer_batches = self.buffer_size // (cfg["seq_len"] - 1)
15 | self.buffer_size = self.buffer_batches * (cfg["seq_len"] - 1)
16 | self.buffer = torch.zeros(
17 | (self.buffer_size, 2, model_A.cfg.d_model),
18 | dtype=torch.bfloat16,
19 | requires_grad=False,
20 | ).to(cfg["device"]) # hardcoding 2 for model diffing
21 | self.cfg = cfg
22 | self.model_A = model_A
23 | self.model_B = model_B
24 | self.token_pointer = 0
25 | self.first = True
26 | self.normalize = True
27 | self.all_tokens = all_tokens
28 |
29 | estimated_norm_scaling_factor_A = self.estimate_norm_scaling_factor(cfg["model_batch_size"], model_A)
30 | estimated_norm_scaling_factor_B = self.estimate_norm_scaling_factor(cfg["model_batch_size"], model_B)
31 |
32 | self.normalisation_factor = torch.tensor(
33 | [
34 | estimated_norm_scaling_factor_A,
35 | estimated_norm_scaling_factor_B,
36 | ],
37 |             device=cfg["device"],
38 | dtype=torch.float32,
39 | )
40 | self.refresh()
41 |
42 | @torch.no_grad()
43 | def estimate_norm_scaling_factor(self, batch_size, model, n_batches_for_norm_estimate: int = 100):
44 | # stolen from SAELens https://github.com/jbloomAus/SAELens/blob/6d6eaef343fd72add6e26d4c13307643a62c41bf/sae_lens/training/activations_store.py#L370
45 | norms_per_batch = []
46 | for i in tqdm.tqdm(
47 | range(n_batches_for_norm_estimate), desc="Estimating norm scaling factor"
48 | ):
49 | tokens = self.all_tokens[i * batch_size : (i + 1) * batch_size]
50 | _, cache = model.run_with_cache(
51 | tokens,
52 | names_filter=self.cfg["hook_point"],
53 | return_type=None,
54 | )
55 | acts = cache[self.cfg["hook_point"]]
56 | # TODO: maybe drop BOS here
57 | norms_per_batch.append(acts.norm(dim=-1).mean().item())
58 | mean_norm = np.mean(norms_per_batch)
59 | scaling_factor = np.sqrt(model.cfg.d_model) / mean_norm
60 |
61 | return scaling_factor
62 |
63 | @torch.no_grad()
64 | def refresh(self):
65 | self.pointer = 0
66 | print("Refreshing the buffer!")
67 | with torch.autocast("cuda", torch.bfloat16):
68 | if self.first:
69 | num_batches = self.buffer_batches
70 | else:
71 | num_batches = self.buffer_batches // 2
72 | self.first = False
73 | for _ in tqdm.trange(0, num_batches, self.cfg["model_batch_size"]):
74 | tokens = self.all_tokens[
75 | self.token_pointer : min(
76 |                         self.token_pointer + self.cfg["model_batch_size"], self.all_tokens.shape[0]
77 | )
78 | ]
79 | _, cache_A = self.model_A.run_with_cache(
80 | tokens, names_filter=self.cfg["hook_point"]
81 | )
82 | cache_A: ActivationCache
83 |
84 | _, cache_B = self.model_B.run_with_cache(
85 | tokens, names_filter=self.cfg["hook_point"]
86 | )
87 | cache_B: ActivationCache
88 |
89 | acts = torch.stack([cache_A[self.cfg["hook_point"]], cache_B[self.cfg["hook_point"]]], dim=0)
90 | acts = acts[:, :, 1:, :] # Drop BOS
91 | assert acts.shape == (2, tokens.shape[0], tokens.shape[1]-1, self.model_A.cfg.d_model) # [2, batch, seq_len, d_model]
92 | acts = einops.rearrange(
93 | acts,
94 | "n_layers batch seq_len d_model -> (batch seq_len) n_layers d_model",
95 | )
96 |
97 | self.buffer[self.pointer : self.pointer + acts.shape[0]] = acts
98 | self.pointer += acts.shape[0]
99 | self.token_pointer += self.cfg["model_batch_size"]
100 |
101 | self.pointer = 0
102 | self.buffer = self.buffer[
103 | torch.randperm(self.buffer.shape[0]).to(self.cfg["device"])
104 | ]
105 |
106 | @torch.no_grad()
107 | def next(self):
108 | out = self.buffer[self.pointer : self.pointer + self.cfg["batch_size"]].float()
109 | # out: [batch_size, n_layers, d_model]
110 | self.pointer += self.cfg["batch_size"]
111 | if self.pointer > self.buffer.shape[0] // 2 - self.cfg["batch_size"]:
112 | self.refresh()
113 | if self.normalize:
114 | out = out * self.normalisation_factor[None, :, None]
115 | return out
116 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | # %%
2 | import os
3 | from IPython import get_ipython
4 |
5 | ipython = get_ipython()
6 | # Code to automatically reload imported modules (like HookedTransformer) as they're edited, without restarting the kernel
7 | if ipython is not None:
8 | ipython.magic("load_ext autoreload")
9 | ipython.magic("autoreload 2")
10 |
11 | import plotly.io as pio
12 | pio.renderers.default = "jupyterlab"
13 |
14 | # Import stuff
15 | import einops
16 | import json
17 | import argparse
18 |
19 | from datasets import load_dataset
20 | from pathlib import Path
21 | import plotly.express as px
22 | from torch.distributions.categorical import Categorical
23 | from tqdm import tqdm
24 | import torch
25 | import numpy as np
26 | from transformer_lens import HookedTransformer
27 | from jaxtyping import Float
28 | from transformer_lens.hook_points import HookPoint
29 |
30 | from functools import partial
31 |
32 | from IPython.display import HTML, display
33 |
34 | from transformer_lens.utils import to_numpy
35 | import pandas as pd
36 |
37 | from html import escape
38 | import colorsys
39 |
40 |
41 | import wandb
42 |
43 | import plotly.graph_objects as go
44 |
45 | update_layout_set = {
46 | "xaxis_range", "yaxis_range", "hovermode", "xaxis_title", "yaxis_title", "colorbar", "colorscale", "coloraxis",
47 | "title_x", "bargap", "bargroupgap", "xaxis_tickformat", "yaxis_tickformat", "title_y", "legend_title_text", "xaxis_showgrid",
48 | "xaxis_gridwidth", "xaxis_gridcolor", "yaxis_showgrid", "yaxis_gridwidth"
49 | }
50 |
51 | def imshow(tensor, renderer=None, xaxis="", yaxis="", **kwargs):
52 | if isinstance(tensor, list):
53 | tensor = torch.stack(tensor)
54 | kwargs_post = {k: v for k, v in kwargs.items() if k in update_layout_set}
55 | kwargs_pre = {k: v for k, v in kwargs.items() if k not in update_layout_set}
56 | if "facet_labels" in kwargs_pre:
57 | facet_labels = kwargs_pre.pop("facet_labels")
58 | else:
59 | facet_labels = None
60 | if "color_continuous_scale" not in kwargs_pre:
61 | kwargs_pre["color_continuous_scale"] = "RdBu"
62 | fig = px.imshow(to_numpy(tensor), color_continuous_midpoint=0.0,labels={"x":xaxis, "y":yaxis}, **kwargs_pre).update_layout(**kwargs_post)
63 | if facet_labels:
64 | for i, label in enumerate(facet_labels):
65 | fig.layout.annotations[i]['text'] = label
66 |
67 | fig.show(renderer)
68 |
69 | def line(tensor, renderer=None, xaxis="", yaxis="", **kwargs):
70 | px.line(y=to_numpy(tensor), labels={"x":xaxis, "y":yaxis}, **kwargs).show(renderer)
71 |
72 | def scatter(x, y, xaxis="", yaxis="", caxis="", renderer=None, return_fig=False, **kwargs):
73 | x = to_numpy(x)
74 | y = to_numpy(y)
75 | fig = px.scatter(y=y, x=x, labels={"x":xaxis, "y":yaxis, "color":caxis}, **kwargs)
76 | if return_fig:
77 | return fig
78 | fig.show(renderer)
79 |
80 | def lines(lines_list, x=None, mode='lines', labels=None, xaxis='', yaxis='', title = '', log_y=False, hover=None, **kwargs):
81 | # Helper function to plot multiple lines
82 | if type(lines_list)==torch.Tensor:
83 | lines_list = [lines_list[i] for i in range(lines_list.shape[0])]
84 | if x is None:
85 | x=np.arange(len(lines_list[0]))
86 | fig = go.Figure(layout={'title':title})
87 | fig.update_xaxes(title=xaxis)
88 | fig.update_yaxes(title=yaxis)
89 | for c, line in enumerate(lines_list):
90 | if type(line)==torch.Tensor:
91 | line = to_numpy(line)
92 | if labels is not None:
93 | label = labels[c]
94 | else:
95 | label = c
96 | fig.add_trace(go.Scatter(x=x, y=line, mode=mode, name=label, hovertext=hover, **kwargs))
97 | if log_y:
98 | fig.update_layout(yaxis_type="log")
99 | fig.show()
100 |
101 | def bar(tensor, renderer=None, xaxis="", yaxis="", **kwargs):
102 | px.bar(
103 | y=to_numpy(tensor),
104 | labels={"x": xaxis, "y": yaxis},
105 | template="simple_white",
106 | **kwargs).show(renderer)
107 |
108 | def create_html(strings, values, saturation=0.5, allow_different_length=False):
109 | # escape strings to deal with tabs, newlines, etc.
110 | escaped_strings = [escape(s, quote=True) for s in strings]
111 | processed_strings = [
112 |         s.replace("\n", "<br/>").replace("\t", "&nbsp;&nbsp;&nbsp;&nbsp;").replace(" ", "&nbsp;")
113 | for s in escaped_strings
114 | ]
115 |
116 | if isinstance(values, torch.Tensor) and len(values.shape)>1:
117 | values = values.flatten().tolist()
118 |
119 | if not allow_different_length:
120 | assert len(processed_strings) == len(values)
121 |
122 | # scale values
123 | max_value = max(max(values), -min(values))+1e-3
124 | scaled_values = [v / max_value * saturation for v in values]
125 |
126 | # create html
127 | html = ""
128 | for i, s in enumerate(processed_strings):
129 |         if i < len(scaled_values):
130 |             v = scaled_values[i]
131 |         else:
132 |             v = 0
133 |         if v < 0:
134 |             hue = 0  # red for negative values
135 |         else:
136 |             hue = 0.66  # blue for positive values
137 |         rgb_color = colorsys.hsv_to_rgb(hue, abs(v), 1)
138 |         hex_color = "#%02x%02x%02x" % (
139 |             int(rgb_color[0] * 255),
140 |             int(rgb_color[1] * 255),
141 |             int(rgb_color[2] * 255),
142 |         )
143 |         # wrap each token in a span shaded by its (scaled) value
144 |         html += f'<span style="background-color: {hex_color}; border: 1px solid lightgray; '
145 |         html += f'font-size: 16px; border-radius: 3px;">{s}</span>'
146 |
147 | display(HTML(html))
148 |
149 | # crosscoder stuff
150 |
151 | def arg_parse_update_cfg(default_cfg):
152 | """
153 | Helper function to take in a dictionary of arguments, convert these to command line arguments, look at what was passed in, and return an updated dictionary.
154 |
155 | If in Ipython, just returns with no changes
156 | """
157 | if get_ipython() is not None:
158 | # Is in IPython
159 | print("In IPython - skipped argparse")
160 | return default_cfg
161 | cfg = dict(default_cfg)
162 | parser = argparse.ArgumentParser()
163 | for key, value in default_cfg.items():
164 | if type(value) == bool:
165 |             # argparse handles booleans poorly: passing --{key} flips the default (store_false if the default is True, store_true otherwise)
166 | if value:
167 | parser.add_argument(f"--{key}", action="store_false")
168 | else:
169 | parser.add_argument(f"--{key}", action="store_true")
170 |
171 | else:
172 | parser.add_argument(f"--{key}", type=type(value), default=value)
173 | args = parser.parse_args()
174 | parsed_args = vars(args)
175 | cfg.update(parsed_args)
176 | print("Updated config")
177 | print(json.dumps(cfg, indent=2))
178 | return cfg
179 |
180 | def load_pile_lmsys_mixed_tokens():
181 | try:
182 | print("Loading data from disk")
183 | all_tokens = torch.load("/workspace/data/pile-lmsys-mix-1m-tokenized-gemma-2.pt")
184 |     except FileNotFoundError:
185 | print("Data is not cached. Loading data from HF")
186 | data = load_dataset(
187 | "ckkissane/pile-lmsys-mix-1m-tokenized-gemma-2",
188 | split="train",
189 | cache_dir="/workspace/cache/"
190 | )
191 | data.save_to_disk("/workspace/data/pile-lmsys-mix-1m-tokenized-gemma-2.hf")
192 | data.set_format(type="torch", columns=["input_ids"])
193 | all_tokens = data["input_ids"]
194 | torch.save(all_tokens, "/workspace/data/pile-lmsys-mix-1m-tokenized-gemma-2.pt")
195 | print(f"Saved tokens to disk")
196 | return all_tokens
--------------------------------------------------------------------------------
/crosscoder.py:
--------------------------------------------------------------------------------
1 |
2 | from utils import *
3 |
4 | from torch import nn
5 | import pprint
6 | import torch.nn.functional as F
7 | from typing import Optional, Union
8 | from huggingface_hub import hf_hub_download
9 |
10 | from typing import NamedTuple
11 |
12 | DTYPES = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}
13 | SAVE_DIR = Path("/workspace/crosscoder-model-diff-replication/checkpoints")
14 |
15 | class LossOutput(NamedTuple):
16 | # loss: torch.Tensor
17 | l2_loss: torch.Tensor
18 | l1_loss: torch.Tensor
19 | l0_loss: torch.Tensor
20 | explained_variance: torch.Tensor
21 | explained_variance_A: torch.Tensor
22 | explained_variance_B: torch.Tensor
23 |
24 | class CrossCoder(nn.Module):
25 | def __init__(self, cfg):
26 | super().__init__()
27 | self.cfg = cfg
28 | d_hidden = self.cfg["dict_size"]
29 | d_in = self.cfg["d_in"]
30 | self.dtype = DTYPES[self.cfg["enc_dtype"]]
31 | torch.manual_seed(self.cfg["seed"])
32 | # hardcoding n_models to 2
33 | self.W_enc = nn.Parameter(
34 | torch.empty(2, d_in, d_hidden, dtype=self.dtype)
35 | )
36 | self.W_dec = nn.Parameter(
37 | torch.nn.init.normal_(
38 | torch.empty(
39 | d_hidden, 2, d_in, dtype=self.dtype
40 | )
41 | )
42 | )
50 |         # Rescale each latent's decoder vector to norm cfg["dec_init_norm"], separately per model
51 | self.W_dec.data = (
52 | self.W_dec.data / self.W_dec.data.norm(dim=-1, keepdim=True) * self.cfg["dec_init_norm"]
53 | )
54 | # Initialise W_enc to be the transpose of W_dec
55 | self.W_enc.data = einops.rearrange(
56 | self.W_dec.data.clone(),
57 | "d_hidden n_models d_model -> n_models d_model d_hidden",
58 | )
59 | self.b_enc = nn.Parameter(torch.zeros(d_hidden, dtype=self.dtype))
60 | self.b_dec = nn.Parameter(
61 | torch.zeros((2, d_in), dtype=self.dtype)
62 | )
63 | self.d_hidden = d_hidden
64 |
65 | self.to(self.cfg["device"])
66 | self.save_dir = None
67 | self.save_version = 0
68 |
69 | def encode(self, x, apply_relu=True):
70 | # x: [batch, n_models, d_model]
71 | x_enc = einops.einsum(
72 | x,
73 | self.W_enc,
74 | "batch n_models d_model, n_models d_model d_hidden -> batch d_hidden",
75 | )
76 | if apply_relu:
77 | acts = F.relu(x_enc + self.b_enc)
78 | else:
79 | acts = x_enc + self.b_enc
80 | return acts
81 |
82 | def decode(self, acts):
83 | # acts: [batch, d_hidden]
84 | acts_dec = einops.einsum(
85 | acts,
86 | self.W_dec,
87 | "batch d_hidden, d_hidden n_models d_model -> batch n_models d_model",
88 | )
89 | return acts_dec + self.b_dec
90 |
91 | def forward(self, x):
92 | # x: [batch, n_models, d_model]
93 | acts = self.encode(x)
94 | return self.decode(acts)
95 |
96 | def get_losses(self, x):
97 | # x: [batch, n_models, d_model]
98 | x = x.to(self.dtype)
99 | acts = self.encode(x)
100 | # acts: [batch, d_hidden]
101 | x_reconstruct = self.decode(acts)
102 | diff = x_reconstruct.float() - x.float()
103 | squared_diff = diff.pow(2)
104 | l2_per_batch = einops.reduce(squared_diff, 'batch n_models d_model -> batch', 'sum')
105 | l2_loss = l2_per_batch.mean()
106 |
107 | total_variance = einops.reduce((x - x.mean(0)).pow(2), 'batch n_models d_model -> batch', 'sum')
108 | explained_variance = 1 - l2_per_batch / total_variance
109 |
110 | per_token_l2_loss_A = (x_reconstruct[:, 0, :] - x[:, 0, :]).pow(2).sum(dim=-1).squeeze()
111 | total_variance_A = (x[:, 0, :] - x[:, 0, :].mean(0)).pow(2).sum(-1).squeeze()
112 | explained_variance_A = 1 - per_token_l2_loss_A / total_variance_A
113 |
114 | per_token_l2_loss_B = (x_reconstruct[:, 1, :] - x[:, 1, :]).pow(2).sum(dim=-1).squeeze()
115 | total_variance_B = (x[:, 1, :] - x[:, 1, :].mean(0)).pow(2).sum(-1).squeeze()
116 | explained_variance_B = 1 - per_token_l2_loss_B / total_variance_B
117 |
118 | decoder_norms = self.W_dec.norm(dim=-1)
119 | # decoder_norms: [d_hidden, n_models]
120 | total_decoder_norm = einops.reduce(decoder_norms, 'd_hidden n_models -> d_hidden', 'sum')
121 | l1_loss = (acts * total_decoder_norm[None, :]).sum(-1).mean(0)
122 |
123 | l0_loss = (acts>0).float().sum(-1).mean()
124 |
125 | return LossOutput(l2_loss=l2_loss, l1_loss=l1_loss, l0_loss=l0_loss, explained_variance=explained_variance, explained_variance_A=explained_variance_A, explained_variance_B=explained_variance_B)
126 |
127 | def create_save_dir(self):
128 |         base_dir = SAVE_DIR
129 | version_list = [
130 | int(file.name.split("_")[1])
131 | for file in list(SAVE_DIR.iterdir())
132 | if "version" in str(file)
133 | ]
134 | if len(version_list):
135 | version = 1 + max(version_list)
136 | else:
137 | version = 0
138 | self.save_dir = base_dir / f"version_{version}"
139 | self.save_dir.mkdir(parents=True)
140 |
141 | def save(self):
142 | if self.save_dir is None:
143 | self.create_save_dir()
144 | weight_path = self.save_dir / f"{self.save_version}.pt"
145 | cfg_path = self.save_dir / f"{self.save_version}_cfg.json"
146 |
147 | torch.save(self.state_dict(), weight_path)
148 | with open(cfg_path, "w") as f:
149 | json.dump(self.cfg, f)
150 |
151 | print(f"Saved as version {self.save_version} in {self.save_dir}")
152 | self.save_version += 1
153 |
154 | @classmethod
155 | def load_from_hf(
156 | cls,
157 | repo_id: str = "ckkissane/crosscoder-gemma-2-2b-model-diff",
158 | path: str = "blocks.14.hook_resid_pre",
159 | device: Optional[Union[str, torch.device]] = None
160 | ) -> "CrossCoder":
161 | """
162 | Load CrossCoder weights and config from HuggingFace.
163 |
164 | Args:
165 | repo_id: HuggingFace repository ID
166 | path: Path within the repo to the weights/config
167 |             device: Device to load the crosscoder onto (defaults to the device
168 |                 stored in the saved cfg if not specified)
169 |
170 | Returns:
171 | Initialized CrossCoder instance
172 | """
173 |
174 | # Download config and weights
175 | config_path = hf_hub_download(
176 | repo_id=repo_id,
177 | filename=f"{path}/cfg.json"
178 | )
179 | weights_path = hf_hub_download(
180 | repo_id=repo_id,
181 | filename=f"{path}/cc_weights.pt"
182 | )
183 |
184 | # Load config
185 | with open(config_path, 'r') as f:
186 | cfg = json.load(f)
187 |
188 | # Override device if specified
189 | if device is not None:
190 | cfg["device"] = str(device)
191 |
192 | # Initialize CrossCoder with config
193 | instance = cls(cfg)
194 |
195 | # Load weights
196 | state_dict = torch.load(weights_path, map_location=cfg["device"])
197 | instance.load_state_dict(state_dict)
198 |
199 | return instance
200 |
201 | @classmethod
202 | def load(cls, version_dir, checkpoint_version):
203 | save_dir = Path("/workspace/crosscoder-model-diff-replication/checkpoints") / str(version_dir)
204 | cfg_path = save_dir / f"{str(checkpoint_version)}_cfg.json"
205 | weight_path = save_dir / f"{str(checkpoint_version)}.pt"
206 |
207 | cfg = json.load(open(cfg_path, "r"))
208 | pprint.pprint(cfg)
209 | self = cls(cfg=cfg)
210 | self.load_state_dict(torch.load(weight_path))
211 | return self
--------------------------------------------------------------------------------