├── .gitignore ├── train.py ├── analysis.py ├── README.md ├── trainer.py ├── buffer.py ├── utils.py └── crosscoder.py /.gitignore: -------------------------------------------------------------------------------- 1 | wandb/* 2 | __pycache__/* -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | # %% 2 | from utils import * 3 | from trainer import Trainer 4 | # %% 5 | device = 'cuda:0' 6 | 7 | base_model = HookedTransformer.from_pretrained( 8 | "gemma-2-2b", 9 | device=device, 10 | ) 11 | 12 | chat_model = HookedTransformer.from_pretrained( 13 | "gemma-2-2b-it", 14 | device=device, 15 | ) 16 | 17 | # %% 18 | all_tokens = load_pile_lmsys_mixed_tokens() 19 | 20 | # %% 21 | default_cfg = { 22 | "seed": 49, 23 | "batch_size": 4096, 24 | "buffer_mult": 128, 25 | "lr": 5e-5, 26 | "num_tokens": 400_000_000, 27 | "l1_coeff": 2, 28 | "beta1": 0.9, 29 | "beta2": 0.999, 30 | "d_in": base_model.cfg.d_model, 31 | "dict_size": 2**14, 32 | "seq_len": 1024, 33 | "enc_dtype": "fp32", 34 | "model_name": "gemma-2-2b", 35 | "site": "resid_pre", 36 | "device": "cuda:0", 37 | "model_batch_size": 4, 38 | "log_every": 100, 39 | "save_every": 30000, 40 | "dec_init_norm": 0.08, 41 | "hook_point": "blocks.14.hook_resid_pre", 42 | "wandb_project": "YOUR_WANDB_PROJECT", 43 | "wandb_entity": "YOUR_WANDB_ENTITY", 44 | } 45 | cfg = arg_parse_update_cfg(default_cfg) 46 | 47 | trainer = Trainer(cfg, base_model, chat_model, all_tokens) 48 | trainer.train() 49 | # %% -------------------------------------------------------------------------------- /analysis.py: -------------------------------------------------------------------------------- 1 | # %% 2 | from utils import * 3 | from crosscoder import CrossCoder 4 | torch.set_grad_enabled(False); 5 | # %% 6 | cross_coder = CrossCoder.load_from_hf() 7 | 8 | # %% 9 | norms = cross_coder.W_dec.norm(dim=-1) 10 | norms.shape 11 | # %% 12 | relative_norms = norms[:, 1] / norms.sum(dim=-1) 13 | relative_norms.shape 14 | # %% 15 | 16 | fig = px.histogram( 17 | relative_norms.detach().cpu().numpy(), 18 | title="Gemma 2 2B Base vs IT Model Diff", 19 | labels={"value": "Relative decoder norm strength"}, 20 | nbins=200, 21 | ) 22 | 23 | fig.update_layout(showlegend=False) 24 | fig.update_yaxes(title_text="Number of Latents") 25 | 26 | # Update x-axis ticks 27 | fig.update_xaxes( 28 | tickvals=[0, 0.25, 0.5, 0.75, 1.0], 29 | ticktext=['0', '0.25', '0.5', '0.75', '1.0'] 30 | ) 31 | 32 | fig.show() 33 | 34 | # %% 35 | shared_latent_mask = (relative_norms < 0.7) & (relative_norms > 0.3) 36 | shared_latent_mask.shape 37 | # %% 38 | # Cosine similarity of recoder vectors between models 39 | 40 | cosine_sims = (cross_coder.W_dec[:, 0, :] * cross_coder.W_dec[:, 1, :]).sum(dim=-1) / (cross_coder.W_dec[:, 0, :].norm(dim=-1) * cross_coder.W_dec[:, 1, :].norm(dim=-1)) 41 | cosine_sims.shape 42 | # %% 43 | import plotly.express as px 44 | import torch 45 | 46 | fig = px.histogram( 47 | cosine_sims[shared_latent_mask].to(torch.float32).detach().cpu().numpy(), 48 | #title="Cosine similarity of decoder vectors between models", 49 | log_y=True, # Sets the y-axis to log scale 50 | range_x=[-1, 1], # Sets the x-axis range from -1 to 1 51 | nbins=100, # Adjust this value to change the number of bins 52 | labels={"value": "Cosine similarity of decoder vectors between models"} 53 | ) 54 | 55 | fig.update_layout(showlegend=False) 56 | fig.update_yaxes(title_text="Number of Latents (log 
scale)")
57 | 
58 | fig.show()
59 | # %%
60 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # TLDR
2 | 
3 | Open source replication of [Anthropic's Crosscoders for Model Diffing](https://transformer-circuits.pub/2024/crosscoders/index.html#model-diffing).
4 | The crosscoder was trained to diff the residual streams of the Gemma-2 2B base and IT models at the middle layer (`blocks.14.hook_resid_pre`).
5 | 
6 | See this [blog post](https://www.lesswrong.com/posts/srt6JXsRMtmqAJavD/open-source-replication-of-anthropic-s-crosscoder-paper-for) for more details.
7 | 
8 | # Reading This Codebase
9 | 
10 | This implementation is adapted from [Neel Nanda's Crosscoders code](https://github.com/neelnanda-io/Crosscoders).
11 | 
12 | * `train.py` is the main training script for the crosscoder. Run it with `python train.py`.
13 | * `trainer.py` contains the PyTorch boilerplate code that actually trains the crosscoder.
14 | * `crosscoder.py` contains the PyTorch implementation of the crosscoder. It was implemented to diff two different models, but should be easily hackable to work with an arbitrary number of models.
15 | * `buffer.py` contains code to extract activations from both models, concatenate them, and store them in a buffer which is shuffled and periodically refreshed during training.
16 | * `analysis.py` is a short notebook replicating some of the results from the Anthropic paper. See this [colab notebook](https://colab.research.google.com/drive/124ODki4dUjfi21nuZPHRySALx9I74YHj?usp=sharing) for a more comprehensive demo.
17 | 
18 | It won't work out of the box, but hopefully it's pretty hackable. Some tips:
19 | * In `train.py` I just set the cfg by editing the code, rather than using command line arguments. You'll need to change "wandb_project" and "wandb_entity" in the cfg dict.
20 | * You'll need to create a checkpoints dir at `/workspace/crosscoder-model-diff-replication/checkpoints` (or change this path in the code). I would sanity-check this with a short test run to make sure your weights are properly saved at the end of training.
21 | * We load training data from https://huggingface.co/datasets/ckkissane/pile-lmsys-mix-1m-tokenized-gemma-2 as a global tensor called `all_tokens`, and pass this to the Trainer in `train.py`. This is very hacky, but should be easy to swap out if needed.
22 | * In `buffer.py` we separately normalize both the base and chat activations such that they both have average norm sqrt(d_model). This should be handled for you during training, but note that you'll also need to normalize activations during analysis (or fold the normalization scaling factors into the crosscoder weights); see the sketch below.
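For the analysis-time normalization, here is a minimal sketch. It mirrors the scheme in `buffer.py` (scale each model's activations so their average norm is sqrt(d_model)), but estimates the scaling factor from a single batch rather than the ~100 batches used by `Buffer.estimate_norm_scaling_factor`. It assumes `base_model`, `chat_model`, `cross_coder`, and a batch of `tokens` have been loaded as in `train.py` / `analysis.py`:

```python
import torch

hook_point = cross_coder.cfg["hook_point"]  # e.g. "blocks.14.hook_resid_pre"

@torch.no_grad()
def get_normalized_acts(model, tokens):
    # Residual stream activations at the crosscoder's hook point
    _, cache = model.run_with_cache(tokens, names_filter=hook_point, return_type=None)
    acts = cache[hook_point][:, 1:, :]  # drop BOS, as buffer.py does
    # Rough single-batch estimate of the norm scaling factor
    scale = (model.cfg.d_model ** 0.5) / acts.norm(dim=-1).mean()
    return acts * scale

# Stack into the [batch * (seq - 1), n_models, d_model] layout the crosscoder expects
acts = torch.stack(
    [get_normalized_acts(base_model, tokens), get_normalized_acts(chat_model, tokens)],
    dim=2,
).reshape(-1, 2, base_model.cfg.d_model)

latents = cross_coder.encode(acts.float())  # [batch * (seq - 1), d_hidden]
```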
23 | -------------------------------------------------------------------------------- /trainer.py: -------------------------------------------------------------------------------- 1 | from utils import * 2 | from crosscoder import CrossCoder 3 | from buffer import Buffer 4 | import tqdm 5 | 6 | from torch.nn.utils import clip_grad_norm_ 7 | class Trainer: 8 | def __init__(self, cfg, model_A, model_B, all_tokens): 9 | self.cfg = cfg 10 | self.model_A = model_A 11 | self.model_B = model_B 12 | self.crosscoder = CrossCoder(cfg) 13 | self.buffer = Buffer(cfg, model_A, model_B, all_tokens) 14 | self.total_steps = cfg["num_tokens"] // cfg["batch_size"] 15 | 16 | self.optimizer = torch.optim.Adam( 17 | self.crosscoder.parameters(), 18 | lr=cfg["lr"], 19 | betas=(cfg["beta1"], cfg["beta2"]), 20 | ) 21 | self.scheduler = torch.optim.lr_scheduler.LambdaLR( 22 | self.optimizer, self.lr_lambda 23 | ) 24 | self.step_counter = 0 25 | 26 | wandb.init(project=cfg["wandb_project"], entity=cfg["wandb_entity"]) 27 | 28 | def lr_lambda(self, step): 29 | if step < 0.8 * self.total_steps: 30 | return 1.0 31 | else: 32 | return 1.0 - (step - 0.8 * self.total_steps) / (0.2 * self.total_steps) 33 | 34 | def get_l1_coeff(self): 35 | # Linearly increases from 0 to cfg["l1_coeff"] over the first 0.05 * self.total_steps steps, then keeps it constant 36 | if self.step_counter < 0.05 * self.total_steps: 37 | return self.cfg["l1_coeff"] * self.step_counter / (0.05 * self.total_steps) 38 | else: 39 | return self.cfg["l1_coeff"] 40 | 41 | def step(self): 42 | acts = self.buffer.next() 43 | losses = self.crosscoder.get_losses(acts) 44 | loss = losses.l2_loss + self.get_l1_coeff() * losses.l1_loss 45 | loss.backward() 46 | clip_grad_norm_(self.crosscoder.parameters(), max_norm=1.0) 47 | self.optimizer.step() 48 | self.scheduler.step() 49 | self.optimizer.zero_grad() 50 | 51 | loss_dict = { 52 | "loss": loss.item(), 53 | "l2_loss": losses.l2_loss.item(), 54 | "l1_loss": losses.l1_loss.item(), 55 | "l0_loss": losses.l0_loss.item(), 56 | "l1_coeff": self.get_l1_coeff(), 57 | "lr": self.scheduler.get_last_lr()[0], 58 | "explained_variance": losses.explained_variance.mean().item(), 59 | "explained_variance_A": losses.explained_variance_A.mean().item(), 60 | "explained_variance_B": losses.explained_variance_B.mean().item(), 61 | } 62 | self.step_counter += 1 63 | return loss_dict 64 | 65 | def log(self, loss_dict): 66 | wandb.log(loss_dict, step=self.step_counter) 67 | print(loss_dict) 68 | 69 | def save(self): 70 | self.crosscoder.save() 71 | 72 | def train(self): 73 | self.step_counter = 0 74 | try: 75 | for i in tqdm.trange(self.total_steps): 76 | loss_dict = self.step() 77 | if i % self.cfg["log_every"] == 0: 78 | self.log(loss_dict) 79 | if (i + 1) % self.cfg["save_every"] == 0: 80 | self.save() 81 | finally: 82 | self.save() -------------------------------------------------------------------------------- /buffer.py: -------------------------------------------------------------------------------- 1 | from utils import * 2 | from transformer_lens import ActivationCache 3 | import tqdm 4 | 5 | class Buffer: 6 | """ 7 | This defines a data buffer, to store a stack of acts across both model that can be used to train the autoencoder. It'll automatically run the model to generate more when it gets halfway empty. 
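    Activations are stored as a [buffer_size, n_models, d_model] bfloat16 tensor; when batches are served by next(), each model's activations are rescaled so that their average norm is roughly sqrt(d_model) (see estimate_norm_scaling_factor).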
8 | """ 9 | 10 | def __init__(self, cfg, model_A, model_B, all_tokens): 11 | assert model_A.cfg.d_model == model_B.cfg.d_model 12 | self.cfg = cfg 13 | self.buffer_size = cfg["batch_size"] * cfg["buffer_mult"] 14 | self.buffer_batches = self.buffer_size // (cfg["seq_len"] - 1) 15 | self.buffer_size = self.buffer_batches * (cfg["seq_len"] - 1) 16 | self.buffer = torch.zeros( 17 | (self.buffer_size, 2, model_A.cfg.d_model), 18 | dtype=torch.bfloat16, 19 | requires_grad=False, 20 | ).to(cfg["device"]) # hardcoding 2 for model diffing 21 | self.cfg = cfg 22 | self.model_A = model_A 23 | self.model_B = model_B 24 | self.token_pointer = 0 25 | self.first = True 26 | self.normalize = True 27 | self.all_tokens = all_tokens 28 | 29 | estimated_norm_scaling_factor_A = self.estimate_norm_scaling_factor(cfg["model_batch_size"], model_A) 30 | estimated_norm_scaling_factor_B = self.estimate_norm_scaling_factor(cfg["model_batch_size"], model_B) 31 | 32 | self.normalisation_factor = torch.tensor( 33 | [ 34 | estimated_norm_scaling_factor_A, 35 | estimated_norm_scaling_factor_B, 36 | ], 37 | device="cuda:0", 38 | dtype=torch.float32, 39 | ) 40 | self.refresh() 41 | 42 | @torch.no_grad() 43 | def estimate_norm_scaling_factor(self, batch_size, model, n_batches_for_norm_estimate: int = 100): 44 | # stolen from SAELens https://github.com/jbloomAus/SAELens/blob/6d6eaef343fd72add6e26d4c13307643a62c41bf/sae_lens/training/activations_store.py#L370 45 | norms_per_batch = [] 46 | for i in tqdm.tqdm( 47 | range(n_batches_for_norm_estimate), desc="Estimating norm scaling factor" 48 | ): 49 | tokens = self.all_tokens[i * batch_size : (i + 1) * batch_size] 50 | _, cache = model.run_with_cache( 51 | tokens, 52 | names_filter=self.cfg["hook_point"], 53 | return_type=None, 54 | ) 55 | acts = cache[self.cfg["hook_point"]] 56 | # TODO: maybe drop BOS here 57 | norms_per_batch.append(acts.norm(dim=-1).mean().item()) 58 | mean_norm = np.mean(norms_per_batch) 59 | scaling_factor = np.sqrt(model.cfg.d_model) / mean_norm 60 | 61 | return scaling_factor 62 | 63 | @torch.no_grad() 64 | def refresh(self): 65 | self.pointer = 0 66 | print("Refreshing the buffer!") 67 | with torch.autocast("cuda", torch.bfloat16): 68 | if self.first: 69 | num_batches = self.buffer_batches 70 | else: 71 | num_batches = self.buffer_batches // 2 72 | self.first = False 73 | for _ in tqdm.trange(0, num_batches, self.cfg["model_batch_size"]): 74 | tokens = self.all_tokens[ 75 | self.token_pointer : min( 76 | self.token_pointer + self.cfg["model_batch_size"], num_batches 77 | ) 78 | ] 79 | _, cache_A = self.model_A.run_with_cache( 80 | tokens, names_filter=self.cfg["hook_point"] 81 | ) 82 | cache_A: ActivationCache 83 | 84 | _, cache_B = self.model_B.run_with_cache( 85 | tokens, names_filter=self.cfg["hook_point"] 86 | ) 87 | cache_B: ActivationCache 88 | 89 | acts = torch.stack([cache_A[self.cfg["hook_point"]], cache_B[self.cfg["hook_point"]]], dim=0) 90 | acts = acts[:, :, 1:, :] # Drop BOS 91 | assert acts.shape == (2, tokens.shape[0], tokens.shape[1]-1, self.model_A.cfg.d_model) # [2, batch, seq_len, d_model] 92 | acts = einops.rearrange( 93 | acts, 94 | "n_layers batch seq_len d_model -> (batch seq_len) n_layers d_model", 95 | ) 96 | 97 | self.buffer[self.pointer : self.pointer + acts.shape[0]] = acts 98 | self.pointer += acts.shape[0] 99 | self.token_pointer += self.cfg["model_batch_size"] 100 | 101 | self.pointer = 0 102 | self.buffer = self.buffer[ 103 | torch.randperm(self.buffer.shape[0]).to(self.cfg["device"]) 104 | ] 105 | 106 | @torch.no_grad() 
107 | def next(self): 108 | out = self.buffer[self.pointer : self.pointer + self.cfg["batch_size"]].float() 109 | # out: [batch_size, n_layers, d_model] 110 | self.pointer += self.cfg["batch_size"] 111 | if self.pointer > self.buffer.shape[0] // 2 - self.cfg["batch_size"]: 112 | self.refresh() 113 | if self.normalize: 114 | out = out * self.normalisation_factor[None, :, None] 115 | return out 116 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | # %% 2 | import os 3 | from IPython import get_ipython 4 | 5 | ipython = get_ipython() 6 | # Code to automatically update the HookedTransformer code as its edited without restarting the kernel 7 | if ipython is not None: 8 | ipython.magic("load_ext autoreload") 9 | ipython.magic("autoreload 2") 10 | 11 | import plotly.io as pio 12 | pio.renderers.default = "jupyterlab" 13 | 14 | # Import stuff 15 | import einops 16 | import json 17 | import argparse 18 | 19 | from datasets import load_dataset 20 | from pathlib import Path 21 | import plotly.express as px 22 | from torch.distributions.categorical import Categorical 23 | from tqdm import tqdm 24 | import torch 25 | import numpy as np 26 | from transformer_lens import HookedTransformer 27 | from jaxtyping import Float 28 | from transformer_lens.hook_points import HookPoint 29 | 30 | from functools import partial 31 | 32 | from IPython.display import HTML 33 | 34 | from transformer_lens.utils import to_numpy 35 | import pandas as pd 36 | 37 | from html import escape 38 | import colorsys 39 | 40 | 41 | import wandb 42 | 43 | import plotly.graph_objects as go 44 | 45 | update_layout_set = { 46 | "xaxis_range", "yaxis_range", "hovermode", "xaxis_title", "yaxis_title", "colorbar", "colorscale", "coloraxis", 47 | "title_x", "bargap", "bargroupgap", "xaxis_tickformat", "yaxis_tickformat", "title_y", "legend_title_text", "xaxis_showgrid", 48 | "xaxis_gridwidth", "xaxis_gridcolor", "yaxis_showgrid", "yaxis_gridwidth" 49 | } 50 | 51 | def imshow(tensor, renderer=None, xaxis="", yaxis="", **kwargs): 52 | if isinstance(tensor, list): 53 | tensor = torch.stack(tensor) 54 | kwargs_post = {k: v for k, v in kwargs.items() if k in update_layout_set} 55 | kwargs_pre = {k: v for k, v in kwargs.items() if k not in update_layout_set} 56 | if "facet_labels" in kwargs_pre: 57 | facet_labels = kwargs_pre.pop("facet_labels") 58 | else: 59 | facet_labels = None 60 | if "color_continuous_scale" not in kwargs_pre: 61 | kwargs_pre["color_continuous_scale"] = "RdBu" 62 | fig = px.imshow(to_numpy(tensor), color_continuous_midpoint=0.0,labels={"x":xaxis, "y":yaxis}, **kwargs_pre).update_layout(**kwargs_post) 63 | if facet_labels: 64 | for i, label in enumerate(facet_labels): 65 | fig.layout.annotations[i]['text'] = label 66 | 67 | fig.show(renderer) 68 | 69 | def line(tensor, renderer=None, xaxis="", yaxis="", **kwargs): 70 | px.line(y=to_numpy(tensor), labels={"x":xaxis, "y":yaxis}, **kwargs).show(renderer) 71 | 72 | def scatter(x, y, xaxis="", yaxis="", caxis="", renderer=None, return_fig=False, **kwargs): 73 | x = to_numpy(x) 74 | y = to_numpy(y) 75 | fig = px.scatter(y=y, x=x, labels={"x":xaxis, "y":yaxis, "color":caxis}, **kwargs) 76 | if return_fig: 77 | return fig 78 | fig.show(renderer) 79 | 80 | def lines(lines_list, x=None, mode='lines', labels=None, xaxis='', yaxis='', title = '', log_y=False, hover=None, **kwargs): 81 | # Helper function to plot multiple lines 82 | if 
type(lines_list)==torch.Tensor: 83 | lines_list = [lines_list[i] for i in range(lines_list.shape[0])] 84 | if x is None: 85 | x=np.arange(len(lines_list[0])) 86 | fig = go.Figure(layout={'title':title}) 87 | fig.update_xaxes(title=xaxis) 88 | fig.update_yaxes(title=yaxis) 89 | for c, line in enumerate(lines_list): 90 | if type(line)==torch.Tensor: 91 | line = to_numpy(line) 92 | if labels is not None: 93 | label = labels[c] 94 | else: 95 | label = c 96 | fig.add_trace(go.Scatter(x=x, y=line, mode=mode, name=label, hovertext=hover, **kwargs)) 97 | if log_y: 98 | fig.update_layout(yaxis_type="log") 99 | fig.show() 100 | 101 | def bar(tensor, renderer=None, xaxis="", yaxis="", **kwargs): 102 | px.bar( 103 | y=to_numpy(tensor), 104 | labels={"x": xaxis, "y": yaxis}, 105 | template="simple_white", 106 | **kwargs).show(renderer) 107 | 108 | def create_html(strings, values, saturation=0.5, allow_different_length=False): 109 | # escape strings to deal with tabs, newlines, etc. 110 | escaped_strings = [escape(s, quote=True) for s in strings] 111 | processed_strings = [ 112 | s.replace("\n", "
").replace("\t", " ").replace(" ", " ") 113 | for s in escaped_strings 114 | ] 115 | 116 | if isinstance(values, torch.Tensor) and len(values.shape)>1: 117 | values = values.flatten().tolist() 118 | 119 | if not allow_different_length: 120 | assert len(processed_strings) == len(values) 121 | 122 | # scale values 123 | max_value = max(max(values), -min(values))+1e-3 124 | scaled_values = [v / max_value * saturation for v in values] 125 | 126 | # create html 127 | html = "" 128 | for i, s in enumerate(processed_strings): 129 | if i{s}' 146 | 147 | display(HTML(html)) 148 | 149 | # crosscoder stuff 150 | 151 | def arg_parse_update_cfg(default_cfg): 152 | """ 153 | Helper function to take in a dictionary of arguments, convert these to command line arguments, look at what was passed in, and return an updated dictionary. 154 | 155 | If in Ipython, just returns with no changes 156 | """ 157 | if get_ipython() is not None: 158 | # Is in IPython 159 | print("In IPython - skipped argparse") 160 | return default_cfg 161 | cfg = dict(default_cfg) 162 | parser = argparse.ArgumentParser() 163 | for key, value in default_cfg.items(): 164 | if type(value) == bool: 165 | # argparse for Booleans is broken rip. Now you put in a flag to change the default --{flag} to set True, --{flag} to set False 166 | if value: 167 | parser.add_argument(f"--{key}", action="store_false") 168 | else: 169 | parser.add_argument(f"--{key}", action="store_true") 170 | 171 | else: 172 | parser.add_argument(f"--{key}", type=type(value), default=value) 173 | args = parser.parse_args() 174 | parsed_args = vars(args) 175 | cfg.update(parsed_args) 176 | print("Updated config") 177 | print(json.dumps(cfg, indent=2)) 178 | return cfg 179 | 180 | def load_pile_lmsys_mixed_tokens(): 181 | try: 182 | print("Loading data from disk") 183 | all_tokens = torch.load("/workspace/data/pile-lmsys-mix-1m-tokenized-gemma-2.pt") 184 | except: 185 | print("Data is not cached. 
Loading data from HF")
186 |         data = load_dataset(
187 |             "ckkissane/pile-lmsys-mix-1m-tokenized-gemma-2",
188 |             split="train",
189 |             cache_dir="/workspace/cache/"
190 |         )
191 |         data.save_to_disk("/workspace/data/pile-lmsys-mix-1m-tokenized-gemma-2.hf")
192 |         data.set_format(type="torch", columns=["input_ids"])
193 |         all_tokens = data["input_ids"]
194 |         torch.save(all_tokens, "/workspace/data/pile-lmsys-mix-1m-tokenized-gemma-2.pt")
195 |         print(f"Saved tokens to disk")
196 |     return all_tokens
--------------------------------------------------------------------------------
/crosscoder.py:
--------------------------------------------------------------------------------
1 | 
2 | from utils import *
3 | 
4 | from torch import nn
5 | import pprint
6 | import torch.nn.functional as F
7 | from typing import Optional, Union
8 | from huggingface_hub import hf_hub_download
9 | 
10 | from typing import NamedTuple
11 | 
12 | DTYPES = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}
13 | SAVE_DIR = Path("/workspace/crosscoder-model-diff-replication/checkpoints")
14 | 
15 | class LossOutput(NamedTuple):
16 |     # loss: torch.Tensor
17 |     l2_loss: torch.Tensor
18 |     l1_loss: torch.Tensor
19 |     l0_loss: torch.Tensor
20 |     explained_variance: torch.Tensor
21 |     explained_variance_A: torch.Tensor
22 |     explained_variance_B: torch.Tensor
23 | 
24 | class CrossCoder(nn.Module):
25 |     def __init__(self, cfg):
26 |         super().__init__()
27 |         self.cfg = cfg
28 |         d_hidden = self.cfg["dict_size"]
29 |         d_in = self.cfg["d_in"]
30 |         self.dtype = DTYPES[self.cfg["enc_dtype"]]
31 |         torch.manual_seed(self.cfg["seed"])
32 |         # hardcoding n_models to 2
33 |         self.W_enc = nn.Parameter(
34 |             torch.empty(2, d_in, d_hidden, dtype=self.dtype)
35 |         )
36 |         self.W_dec = nn.Parameter(
37 |             torch.nn.init.normal_(
38 |                 torch.empty(
39 |                     d_hidden, 2, d_in, dtype=self.dtype
40 |                 )
41 |             )
42 |         )
50 |         # Rescale each decoder vector to norm cfg["dec_init_norm"], separately per model
51 |         self.W_dec.data = (
52 |             self.W_dec.data / self.W_dec.data.norm(dim=-1, keepdim=True) * self.cfg["dec_init_norm"]
53 |         )
54 |         # Initialise W_enc to be the transpose of W_dec
55 |         self.W_enc.data = einops.rearrange(
56 |             self.W_dec.data.clone(),
57 |             "d_hidden n_models d_model -> n_models d_model d_hidden",
58 |         )
59 |         self.b_enc = nn.Parameter(torch.zeros(d_hidden, dtype=self.dtype))
60 |         self.b_dec = nn.Parameter(
61 |             torch.zeros((2, d_in), dtype=self.dtype)
62 |         )
63 |         self.d_hidden = d_hidden
64 | 
65 |         self.to(self.cfg["device"])
66 |         self.save_dir = None
67 |         self.save_version = 0
68 | 
69 |     def encode(self, x, apply_relu=True):
70 |         # x: [batch, n_models, d_model]
71 |         x_enc = einops.einsum(
72 |             x,
73 |             self.W_enc,
74 |             "batch n_models d_model, n_models d_model d_hidden -> batch d_hidden",
75 |         )
76 |         if apply_relu:
77 |             acts = F.relu(x_enc + self.b_enc)
78 |         else:
79 |             acts = x_enc + self.b_enc
80 |         return acts
81 | 
82 |     def decode(self, acts):
83 |         # acts: [batch, d_hidden]
84 |         acts_dec = einops.einsum(
85 |             acts,
86 |             self.W_dec,
87 |             "batch d_hidden, d_hidden n_models d_model -> batch n_models d_model",
88 |         )
89 |         return acts_dec + self.b_dec
90 | 
91 |     def forward(self, x):
92 |         # x: [batch, n_models, d_model]
93 |         acts = self.encode(x)
94 |         return self.decode(acts)
95 | 
96 |     def get_losses(self, x):
97 |         # x: [batch, n_models, d_model]
98 |         x = x.to(self.dtype)
99 |         acts = self.encode(x)
100 |         # acts: [batch, d_hidden]
101 | 
x_reconstruct = self.decode(acts) 102 | diff = x_reconstruct.float() - x.float() 103 | squared_diff = diff.pow(2) 104 | l2_per_batch = einops.reduce(squared_diff, 'batch n_models d_model -> batch', 'sum') 105 | l2_loss = l2_per_batch.mean() 106 | 107 | total_variance = einops.reduce((x - x.mean(0)).pow(2), 'batch n_models d_model -> batch', 'sum') 108 | explained_variance = 1 - l2_per_batch / total_variance 109 | 110 | per_token_l2_loss_A = (x_reconstruct[:, 0, :] - x[:, 0, :]).pow(2).sum(dim=-1).squeeze() 111 | total_variance_A = (x[:, 0, :] - x[:, 0, :].mean(0)).pow(2).sum(-1).squeeze() 112 | explained_variance_A = 1 - per_token_l2_loss_A / total_variance_A 113 | 114 | per_token_l2_loss_B = (x_reconstruct[:, 1, :] - x[:, 1, :]).pow(2).sum(dim=-1).squeeze() 115 | total_variance_B = (x[:, 1, :] - x[:, 1, :].mean(0)).pow(2).sum(-1).squeeze() 116 | explained_variance_B = 1 - per_token_l2_loss_B / total_variance_B 117 | 118 | decoder_norms = self.W_dec.norm(dim=-1) 119 | # decoder_norms: [d_hidden, n_models] 120 | total_decoder_norm = einops.reduce(decoder_norms, 'd_hidden n_models -> d_hidden', 'sum') 121 | l1_loss = (acts * total_decoder_norm[None, :]).sum(-1).mean(0) 122 | 123 | l0_loss = (acts>0).float().sum(-1).mean() 124 | 125 | return LossOutput(l2_loss=l2_loss, l1_loss=l1_loss, l0_loss=l0_loss, explained_variance=explained_variance, explained_variance_A=explained_variance_A, explained_variance_B=explained_variance_B) 126 | 127 | def create_save_dir(self): 128 | base_dir = Path("/workspace/crosscoder-model-diff-replication/checkpoints") 129 | version_list = [ 130 | int(file.name.split("_")[1]) 131 | for file in list(SAVE_DIR.iterdir()) 132 | if "version" in str(file) 133 | ] 134 | if len(version_list): 135 | version = 1 + max(version_list) 136 | else: 137 | version = 0 138 | self.save_dir = base_dir / f"version_{version}" 139 | self.save_dir.mkdir(parents=True) 140 | 141 | def save(self): 142 | if self.save_dir is None: 143 | self.create_save_dir() 144 | weight_path = self.save_dir / f"{self.save_version}.pt" 145 | cfg_path = self.save_dir / f"{self.save_version}_cfg.json" 146 | 147 | torch.save(self.state_dict(), weight_path) 148 | with open(cfg_path, "w") as f: 149 | json.dump(self.cfg, f) 150 | 151 | print(f"Saved as version {self.save_version} in {self.save_dir}") 152 | self.save_version += 1 153 | 154 | @classmethod 155 | def load_from_hf( 156 | cls, 157 | repo_id: str = "ckkissane/crosscoder-gemma-2-2b-model-diff", 158 | path: str = "blocks.14.hook_resid_pre", 159 | device: Optional[Union[str, torch.device]] = None 160 | ) -> "CrossCoder": 161 | """ 162 | Load CrossCoder weights and config from HuggingFace. 
163 | 
164 |         Args:
165 |             repo_id: HuggingFace repository ID
166 |             path: Path within the repo to the weights/config
167 |             device: Device to load the crosscoder onto (defaults to the device in the saved cfg if not specified)
168 | 
169 |         Returns:
170 |             Initialized CrossCoder instance
171 |         """
172 | 
173 |         # Download config and weights
174 |         config_path = hf_hub_download(
175 |             repo_id=repo_id,
176 |             filename=f"{path}/cfg.json"
177 |         )
178 |         weights_path = hf_hub_download(
179 |             repo_id=repo_id,
180 |             filename=f"{path}/cc_weights.pt"
181 |         )
182 | 
183 |         # Load config
184 |         with open(config_path, 'r') as f:
185 |             cfg = json.load(f)
186 | 
187 |         # Override device if specified
188 |         if device is not None:
189 |             cfg["device"] = str(device)
190 | 
191 |         # Initialize CrossCoder with config
192 |         instance = cls(cfg)
193 | 
194 |         # Load weights
195 |         state_dict = torch.load(weights_path, map_location=cfg["device"])
196 |         instance.load_state_dict(state_dict)
197 | 
198 |         return instance
199 | 
200 |     @classmethod
201 |     def load(cls, version_dir, checkpoint_version):
202 |         save_dir = Path("/workspace/crosscoder-model-diff-replication/checkpoints") / str(version_dir)
203 |         cfg_path = save_dir / f"{str(checkpoint_version)}_cfg.json"
204 |         weight_path = save_dir / f"{str(checkpoint_version)}.pt"
205 | 
206 |         cfg = json.load(open(cfg_path, "r"))
207 |         pprint.pprint(cfg)
208 |         self = cls(cfg=cfg)
209 |         self.load_state_dict(torch.load(weight_path))
210 |         return self
--------------------------------------------------------------------------------
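For reference, a minimal usage sketch of the two loaders above and the decoder-norm statistic computed in `analysis.py`. The HuggingFace defaults are the ones baked into `load_from_hf`; the local `version_0` / checkpoint `2` values are hypothetical examples of the layout produced by `save()`:

```python
import torch
from crosscoder import CrossCoder

torch.set_grad_enabled(False)

# From HuggingFace: defaults point at the released Gemma-2 2B base-vs-IT crosscoder
cross_coder = CrossCoder.load_from_hf(device="cuda:0")

# Or from a local run saved by Trainer / CrossCoder.save(), e.g.
# /workspace/crosscoder-model-diff-replication/checkpoints/version_0/2.pt
# (version and checkpoint numbers here are illustrative)
local_cross_coder = CrossCoder.load("version_0", 2)

# Per-latent decoder norms for each model: [d_hidden, n_models]
norms = cross_coder.W_dec.norm(dim=-1)
# Share of decoder norm attributed to the chat model, as plotted in analysis.py
relative_norms = norms[:, 1] / norms.sum(dim=-1)
```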