├── .gitignore
├── LICENSE
├── README.md
├── __pycache__
    └── utils.cpython-310.pyc
├── crosscoders
    ├── __init__.py
    ├── train.py
    └── utils.py
├── rare_freq_dir.pt
├── scratch.py
├── scratch_2.py
└── setup.py


/.gitignore:
--------------------------------------------------------------------------------
wandb/*
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2023 Neel Nanda

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
Code to train a GPT-2 Small acausal crosscoder

Details (written in the context of someone training a model diff crosscoder):
* I based this on my initial replication of Towards Monosemanticity, not any of the SAE libraries (those seemed clunkier to adapt). Not sure this was the right call.
* Key files: `utils.py` has basically everything important; `train.py` is a tiny file that calls and runs the trainer, and calling it with e.g. `--l1_coeff=2` will update the config used; `scratch.py` has some analysis code. Ignore all other files as irrelevant.
* I decided to implement it with `W_enc` having shape `[n_layers, d_model, d_sae]` and `W_dec` having shape `[d_sae, n_layers, d_model]` (here you'd change `n_layers` to two). It'd also be reasonable to implement it by flattening into a single `n_layers * d_model` axis and having a funkier loss function that needs to unflatten, but this felt more elegant to me.
* I followed the Anthropic April update method, plus some adaptations from the crosscoder post such as the loss function.
* I separately computed and hard coded the normalisation factors; I think these are fairly important. Probably less so here, since base and chat should have very similar norms(?)
* This uses ReLU and L1. I expect TopK or JumpReLU would just be better (no shrinkage or "needing to have small activations" issues) and basically work fine, though Anthropic did say something about the weird loss (each activation weighted by the sum of the L2 norms of its per-layer decoder vectors) incentivising layer sparsity, which may be lost with those? It's probably fine to stick with it as is. Gated with their L1 loss variant may also be fine, idk.
* There's a buffer which periodically runs the model on several batches, stores a shuffled mix of the activations and serves them. You'll need to adapt this to run both chat and base (ideally with the same control tokens in both so it's perfectly matched, unless this breaks the base model?)
* I store a pre-tokenized dataset locally and just load it as a global tensor called `all_tokens`. This is very hacky, but should be easy to swap out.
* I found that it was very sensitive to the `W_dec` init norm: I initially made each d_model vector have norm 0.1 and this went terribly. I think the norm of the flattened vector should probably be 0.1? I just fiddled a bit and found something kinda fine (`default_cfg` currently uses `dec_init_norm=0.005` per layer).
* Probably not relevant to you, but I found that the crosscoder was much better on earlier layers than later ones (e.g. 35% FVU on layer 10, <10% on the first few layers).

-----
# This is all old stuff that's probably no longer relevant
# TLDR

This is an open source replication of [Anthropic's Towards Monosemanticity](https://transformer-circuits.pub/2023/monosemantic-features/index.html) paper. The autoencoder was trained on the gelu-1l model in TransformerLens; you can access two trained autoencoders and the model using [this tutorial](https://colab.research.google.com/drive/1u8larhpxy8w4mMsJiSBddNOzFGj7_RTn#scrollTo=MYrIYDEfBtbL).

# Reading This Codebase

This is a pretty scrappy training codebase, and won't run from the top. I mostly recommend reading the code and copying snippets. See also [Hoagy Cunningham's Github](https://github.com/HoagyC/sparse_coding).

* `utils.py` contains various utils to define the Autoencoder, data Buffer and training data.
* Toggle `loading_data_first_time` to True to load and process the text data used to run the model and generate acts.
* `train.py` is a scrappy training script.
* `cfg["remove_rare_dir"]` was an experiment in training an autoencoder whose features were all orthogonal to the shared direction among rare features; those lines of code can be ignored and weren't used for the open source autoencoders.
* There was a bug in the code to set the decoder weights to have unit norm: it makes the gradients orthogonal, but I forgot to *also* set the norm to be 1 again after each gradient update (turns out a vector of unit norm plus a perpendicular vector does not remain unit norm!). I think I have now fixed the bug.
* `analysis.py` is a scrappy set of experiments for exploring the autoencoder. I recommend reading the Colab tutorial instead for something cleaner and better commented.

Setup Notes:

* Create data: you'll need to set the flag `loading_data_first_time` to True in `utils.py`. Note that this downloads the training mix of gelu-1l; if using e.g. the Pythia models you'll need different data (I recommend https://huggingface.co/datasets/monology/pile-uncopyrighted).
* A bunch of folders are hard coded to be /workspace/...; change these for your system.
* Create a checkpoints dir in /workspace/1L-Sparse-Autoencoder/checkpoints

* If you train an autoencoder and want to share the weights, copy the final checkpoints to a new folder, use `upload_folder_to_hf` to upload to HuggingFace, and create your own repo. Run `huggingface-cli login` to log in, then `apt-get install git-lfs` and `git lfs install`.

--------------------------------------------------------------------------------
/__pycache__/utils.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/neelnanda-io/Crosscoders/3adc7eb23a5a56f12557d2c6c206f0aa688bdbd3/__pycache__/utils.cpython-310.pyc
--------------------------------------------------------------------------------
/crosscoders/__init__.py:
--------------------------------------------------------------------------------
from .utils import *
--------------------------------------------------------------------------------
/crosscoders/train.py:
--------------------------------------------------------------------------------
# %%
from utils import *
trainer = Trainer(cfg, model)
trainer.train()
# %%

--------------------------------------------------------------------------------
/crosscoders/utils.py:
--------------------------------------------------------------------------------
# %%
import os

# os.environ["TRANSFORMERS_CACHE"] = "/workspace/cache/"
# os.environ["DATASETS_CACHE"] = "/workspace/cache/"
# %%
import argparse
import json
import pprint
import random
from pathlib import Path
from typing import NamedTuple

import einops
import huggingface_hub
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import tqdm
import wandb
from datasets import load_dataset
from IPython import get_ipython
from torch.nn.utils import clip_grad_norm_
from transformer_lens import ActivationCache, HookedTransformer

# %%


def arg_parse_update_cfg(default_cfg):
    """
    Helper function to take in a dictionary of arguments, convert these to command line arguments, look at what was passed in, and return an updated dictionary.

    If in IPython, just returns with no changes.
    """
    if get_ipython() is not None:
        # Is in IPython
        print("In IPython - skipped argparse")
        return default_cfg
    cfg = dict(default_cfg)
    parser = argparse.ArgumentParser()
    for key, value in default_cfg.items():
        if isinstance(value, bool):
            # argparse's type=bool is broken (bool("False") is True), so Booleans become flags:
            # passing --{key} flips the default (store_false if it is True, store_true if it is False)
            if value:
                parser.add_argument(f"--{key}", action="store_false")
            else:
                parser.add_argument(f"--{key}", action="store_true")

        else:
            parser.add_argument(f"--{key}", type=type(value), default=value)
    args = parser.parse_args()
    parsed_args = vars(args)
    cfg.update(parsed_args)
    print("Updated config")
    print(json.dumps(cfg, indent=2))
    return cfg


default_cfg = {
    "seed": 51,
    "batch_size": 2048,
    "buffer_mult": 512,
    "lr": 2e-5,
    "num_tokens": int(4e8),
    "l1_coeff": 2,
    "beta1": 0.9,
    "beta2": 0.999,
    "dict_size": 2**16,
    "seq_len": 1024,
    "enc_dtype": "fp32",
    # "remove_rare_dir": False,
    "model_name": "gpt2-small",
    "site": "resid_post",
    # "layer": 0,
    "device": "cuda:0",
    "model_batch_size": 32,
    "log_every": 100,
    "save_every": 100000,
    "dec_init_norm": 0.005,
}

cfg = arg_parse_update_cfg(default_cfg)


def post_init_cfg(cfg):
    cfg["name"] = f"{cfg['model_name']}_{cfg['dict_size']}_{cfg['site']}"


post_init_cfg(cfg)
pprint.pprint(cfg)
# %%

SEED = cfg["seed"]
GENERATOR = torch.manual_seed(SEED)
DTYPES = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}
np.random.seed(SEED)
random.seed(SEED)
torch.set_grad_enabled(True)

# train.py and scratch.py use this global `model` via `from utils import *`
model: HookedTransformer = (
    HookedTransformer.from_pretrained(cfg["model_name"])
    .to(DTYPES[cfg["enc_dtype"]])
    .to(cfg["device"])
)

# n_layers = model.cfg.n_layers
# d_model = model.cfg.d_model
# n_heads = model.cfg.n_heads
# d_head = model.cfg.d_head
# d_mlp = model.cfg.d_mlp
# d_vocab = model.cfg.d_vocab

# %%
# Replace with your own path
SAVE_DIR = Path("/workspace/Crosscoders/checkpoints")


class LossOutput(NamedTuple):
    l2_loss: torch.Tensor
    l1_loss: torch.Tensor
    l0_loss: torch.Tensor


class CrossCoder(nn.Module):
    def __init__(self, cfg, model):
        super().__init__()
        self.cfg = cfg
        d_hidden = self.cfg["dict_size"]
        self.dtype = DTYPES[self.cfg["enc_dtype"]]
        torch.manual_seed(self.cfg["seed"])
        self.W_enc = nn.Parameter(
            torch.empty(
                model.cfg.n_layers, model.cfg.d_model, d_hidden, dtype=self.dtype
            )
        )
        self.W_dec = nn.Parameter(
            torch.nn.init.normal_(
                torch.empty(
                    d_hidden, model.cfg.n_layers, model.cfg.d_model, dtype=self.dtype
                )
            )
        )
        # Rescale W_dec so each feature's decoder vector has norm cfg["dec_init_norm"], separately per layer
        self.W_dec.data = (
            self.W_dec.data
            / self.W_dec.data.norm(dim=-1, keepdim=True)
            * self.cfg["dec_init_norm"]
        )
        # Initialise W_enc to be the transpose of W_dec
        self.W_enc.data = einops.rearrange(
            self.W_dec.data.clone(),
            "d_hidden n_layers d_model -> n_layers d_model d_hidden",
        )
        self.b_enc = nn.Parameter(torch.zeros(d_hidden, dtype=self.dtype))
        self.b_dec = nn.Parameter(
            torch.zeros((model.cfg.n_layers, model.cfg.d_model), dtype=self.dtype)
        )

        self.d_hidden = d_hidden

        self.to(self.cfg["device"])

        self.save_dir = None
        self.save_version = 0

    def encode(self, x, apply_relu=True):
        # x: [batch, n_layers, d_model]
        x_enc = einops.einsum(
            x,
            self.W_enc,
            "... n_layers d_model, n_layers d_model d_hidden -> ... d_hidden",
        )
        if apply_relu:
            acts = F.relu(x_enc + self.b_enc)
        else:
            acts = x_enc + self.b_enc
        return acts

    def decode(self, acts):
        # acts: [batch, d_hidden]
        acts_dec = einops.einsum(
            acts,
            self.W_dec,
            "... d_hidden, d_hidden n_layers d_model -> ... n_layers d_model",
        )
        return acts_dec + self.b_dec

    def forward(self, x):
        # x: [batch, n_layers, d_model]
        acts = self.encode(x)
        return self.decode(acts)

    def get_losses(self, x):
        # x: [batch, n_layers, d_model]
        x = x.to(self.dtype)
        acts = self.encode(x)
        # acts: [batch, d_hidden]
        x_reconstruct = self.decode(acts)
        diff = x_reconstruct.float() - x.float()
        squared_diff = diff.pow(2)
        # Squared error summed over layers and d_model, then averaged over the batch
        l2_per_batch = einops.reduce(squared_diff, "... n_layers d_model -> ...", "sum")
        l2_loss = l2_per_batch.mean()

        decoder_norms = self.W_dec.norm(dim=-1)
        # decoder_norms: [d_hidden, n_layers]
        total_decoder_norm = einops.reduce(
            decoder_norms, "d_hidden n_layers -> d_hidden", "sum"
        )
        # Sparsity penalty: each activation weighted by the sum of its per-layer decoder norms
        l1_loss = (acts * total_decoder_norm[None, :]).sum(-1).mean(0)

        l0_loss = (acts > 0).float().sum(-1).mean()

        return LossOutput(l2_loss=l2_loss, l1_loss=l1_loss, l0_loss=l0_loss)

    def create_save_dir(self):
        SAVE_DIR.mkdir(parents=True, exist_ok=True)
        version_list = [
            int(file.name.split("_")[1])
            for file in list(SAVE_DIR.iterdir())
            if file.name.startswith("version_")
        ]
        if len(version_list):
            version = 1 + max(version_list)
        else:
            version = 0
        self.save_dir = SAVE_DIR / f"version_{version}"
        self.save_dir.mkdir(parents=True)

    def save(self):
        if self.save_dir is None:
            self.create_save_dir()
        weight_path = self.save_dir / f"{self.save_version}.pt"
        cfg_path = self.save_dir / f"{self.save_version}_cfg.json"

        torch.save(self.state_dict(), weight_path)
        with open(cfg_path, "w") as f:
            json.dump(self.cfg, f)

        print(f"Saved as version {self.save_version} in {self.save_dir}")
        self.save_version += 1

    @classmethod
    def load(
        cls,
        name,
        model=None,
        path="",
    ):
        # If the files are not in the default save directory, you can specify a path
        # It's assumed that weights are [name].pt and cfg is [name]_cfg.json
        if path == "":
            save_dir = SAVE_DIR
        else:
            save_dir = Path(path)
        cfg_path = save_dir / f"{str(name)}_cfg.json"
        weight_path = save_dir / f"{str(name)}.pt"

        with open(cfg_path, "r") as f:
            cfg = json.load(f)
        pprint.pprint(cfg)
        if model is None:
            model = (
                HookedTransformer.from_pretrained(cfg["model_name"])
                .to(DTYPES[cfg["enc_dtype"]])
                .to(cfg["device"])
            )
        self = cls(cfg=cfg, model=model)
        self.load_state_dict(torch.load(weight_path))
        return self

    @classmethod
    def load_from_hf(cls, version):
        """
        Loads the saved autoencoder from HuggingFace.

        Version is expected to be an int, or "run1" or "run2"

        version 25 is the final checkpoint of the first autoencoder run,
        version 47 is the final checkpoint of the second autoencoder run.

        NOTE: legacy code from the 1L sparse autoencoder repo that will not run as-is:
        `utils.download_file_from_hf` is not defined in this codebase, and
        CrossCoder.__init__ also requires a `model` argument.
        """
        if version == "run1":
            version = 25
        elif version == "run2":
            version = 47

        cfg = utils.download_file_from_hf(
            "NeelNanda/sparse_autoencoder", f"{version}_cfg.json"
        )
        pprint.pprint(cfg)
        self = cls(cfg=cfg)
        self.load_state_dict(
            utils.download_file_from_hf(
                "NeelNanda/sparse_autoencoder", f"{version}.pt", force_is_torch=True
            )
        )
        return self


# %%
def get_stacked_resids(model, tokens, drop_bos=True):
    """
    Returns the stacked activations of the resid_post hook for all layers.

    This could be made about 2x more memory efficient with a buffer, but who cares.

    If drop_bos is True, the resids on the BOS (first) token are dropped.

    Returns stacked_resids: [batch, seq_len (- 1), n_layers, d_model]
    """
    _, cache = model.run_with_cache(
        tokens, names_filter=lambda x: x.endswith("resid_post")
    )
    stacked_resids = torch.stack(
        [cache["resid_post", i] for i in range(model.cfg.n_layers)], dim=-2
    )
    if drop_bos:
        stacked_resids = stacked_resids[:, 1:, :, :]
    return stacked_resids


# %%
def shuffle_data(all_tokens):
    print("Shuffling data")
    return all_tokens[torch.randperm(all_tokens.shape[0])]


def push_to_hub(local_dir):
    if isinstance(local_dir, huggingface_hub.Repository):
        local_dir = local_dir.local_dir
    os.system(f"git -C {local_dir} add .")
    os.system(f"git -C {local_dir} commit -m 'Auto Commit'")
    os.system(f"git -C {local_dir} push")


def upload_folder_to_hf(folder_path, repo_name=None, debug=False):
    """
    Uploads a folder to HuggingFace, and creates a repo for it.
    """
    folder_path = Path(folder_path)
    if repo_name is None:
        repo_name = folder_path.name
    repo_folder = folder_path.parent / (folder_path.name + "_repo")
    repo_url = huggingface_hub.create_repo(repo_name, exist_ok=True)
    repo = huggingface_hub.Repository(repo_folder, repo_url)

    for file in folder_path.iterdir():
        if debug:
            print(file.name)
        file.rename(repo_folder / file.name)
    push_to_hub(repo.local_dir)


# loading_data_first_time = False
# if loading_data_first_time:
#     raise NotImplementedError("This is not implemented yet")
#     data = load_dataset(
#         "NeelNanda/c4-code-tokenized-2b", split="train", cache_dir="/workspace/cache/"
#     )
#     data.save_to_disk("/workspace/data/c4_code_tokenized_2b.hf")
#     data.set_format(type="torch", columns=["tokens"])
#     all_tokens = data["tokens"]
#     all_tokens.shape

#     all_tokens_reshaped = einops.rearrange(
#         all_tokens, "batch (x seq_len) -> (batch x) seq_len", x=8, seq_len=128
#     )
#     all_tokens_reshaped[:, 0] = model.tokenizer.bos_token_id
#     all_tokens_reshaped = all_tokens_reshaped[
#         torch.randperm(all_tokens_reshaped.shape[0])
#     ]
#     torch.save(all_tokens_reshaped, "/workspace/data/c4_code_2b_tokens_reshaped.pt")
# else:
#     # data = datasets.load_from_disk("/workspace/data/c4_code_tokenized_2b.hf")
#     all_tokens = torch.load("/workspace/data/owt_tensor.pt")
#     # all_tokens = all_tokens[: cfg["num_tokens"] // cfg["seq_len"]]
#     # all_tokens = shuffle_data(all_tokens)

# Buffer.refresh reads this global tensor of pre-tokenized sequences. Replace with your own path.
all_tokens = torch.load("/workspace/data/owt_tensor.pt")


# %%
class Buffer:
    """
    This defines a data buffer, to store a stack of acts across all layers that can be used to train the autoencoder. It'll automatically run the model to generate more when it gets halfway empty.
    """

    def __init__(self, cfg, model):
        self.cfg = cfg
        self.buffer_size = cfg["batch_size"] * cfg["buffer_mult"]
        # Number of sequences that fit in the buffer (each gives seq_len - 1 positions once BOS is dropped)
        self.buffer_batches = self.buffer_size // (cfg["seq_len"] - 1)
        self.buffer_size = self.buffer_batches * (cfg["seq_len"] - 1)
        self.buffer = torch.zeros(
            (self.buffer_size, model.cfg.n_layers, model.cfg.d_model),
            dtype=torch.bfloat16,
            requires_grad=False,
            device=cfg["device"],
        )
        self.model = model
        self.token_pointer = 0
        self.first = True
        self.normalize = True
        # average norm of residuals per layer / sqrt(d_model), one entry per GPT-2 Small layer
        # We divide by this to normalise the data. This is *not* mean centered.
        self.normalisation_factor = torch.tensor(
            [
                1.8281,
                2.0781,
                2.2031,
                2.4062,
                2.5781,
                2.8281,
                3.1562,
                3.6875,
                4.3125,
                5.4062,
                7.8750,
                16.5000,
            ],
            device=cfg["device"],
            dtype=torch.float32,
        )
        # The factors when mean centering (ie using the standard deviation)
        # self.normalisation_factor = torch.tensor(
        #     [
        #         1.4248,
        #         1.5720,
        #         1.6795,
        #         1.8498,
        #         2.0202,
        #         2.2450,
        #         2.5181,
        #         2.9152,
        #         3.3975,
        #         4.1135,
        #         5.1676,
        #         8.3306,
        #     ],
        #     device=cfg["device"],
        # )
        self.refresh()

    @torch.no_grad()
    def refresh(self):
        self.pointer = 0
        print("Refreshing the buffer!")
        with torch.autocast("cuda", torch.bfloat16):
            # Fill the whole buffer the first time; afterwards only the (already consumed) first half
            if self.first:
                num_batches = self.buffer_batches
            else:
                num_batches = self.buffer_batches // 2
            self.first = False
            for i in tqdm.trange(0, num_batches, self.cfg["model_batch_size"]):
                # Clip the last model batch so exactly num_batches sequences are processed
                n_seqs = min(self.cfg["model_batch_size"], num_batches - i)
                tokens = all_tokens[self.token_pointer : self.token_pointer + n_seqs]
                _, cache = self.model.run_with_cache(
                    tokens, names_filter=lambda x: x.endswith("resid_post")
                )
                cache: ActivationCache

                acts = cache.stack_activation("resid_post")
                acts = acts[:, :, 1:, :]  # Drop BOS
                acts = einops.rearrange(
                    acts,
                    "n_layers batch seq_len d_model -> (batch seq_len) n_layers d_model",
                )

                # print(tokens.shape, acts.shape, self.pointer, self.token_pointer)
                self.buffer[self.pointer : self.pointer + acts.shape[0]] = acts
                self.pointer += acts.shape[0]
                self.token_pointer += n_seqs
                # if self.token_pointer > all_tokens.shape[0] - self.cfg["model_batch_size"]:
                #     self.token_pointer = 0

        self.pointer = 0
        self.buffer = self.buffer[
            torch.randperm(self.buffer.shape[0]).to(self.cfg["device"])
        ]

    @torch.no_grad()
    def next(self):
        out = self.buffer[self.pointer : self.pointer + self.cfg["batch_size"]].float()
        # out: [batch_size, n_layers, d_model]
        self.pointer += self.cfg["batch_size"]
        if self.pointer > self.buffer.shape[0] // 2 - self.cfg["batch_size"]:
            # print("Refreshing the buffer!")
            self.refresh()
        if self.normalize:
            # Divide by the per-layer factors so each layer's vectors have average norm sqrt(d_model).
            # (The commented-out alternative factors are standard deviations, motivated by the mean being easy to learn.)
            out = out / self.normalisation_factor[None, :, None]
        return out


class Trainer:
    def __init__(self, cfg, model, use_wandb=True):
        self.cfg = cfg
        self.model = model
        self.crosscoder = CrossCoder(cfg, model)
        self.buffer = Buffer(cfg, model)
        self.total_steps = cfg["num_tokens"] // cfg["batch_size"]

        self.optimizer = torch.optim.Adam(
            self.crosscoder.parameters(),
            lr=cfg["lr"],
            betas=(cfg["beta1"], cfg["beta2"]),
        )
        self.scheduler = torch.optim.lr_scheduler.LambdaLR(
            self.optimizer, self.lr_lambda
        )
        self.step_counter = 0

        self.use_wandb = use_wandb
        if use_wandb:
            wandb.init(project="crosscoder", entity="neelnanda-io", config=cfg)

    def lr_lambda(self, step):
        # Linear warmup over the first 5% of steps, constant until 80%, then linear decay to 0
        if step < 0.05 * self.total_steps:
            return step / (0.05 * self.total_steps)
        elif step < 0.8 * self.total_steps:
            return 1.0
        else:
            return 1.0 - (step - 0.8 * self.total_steps) / (0.2 * self.total_steps)

    def get_l1_coeff(self):
        # Linearly increases from 0 to cfg["l1_coeff"] over the first 0.05 * self.total_steps steps, then keeps it constant
        if self.step_counter < 0.05 * self.total_steps:
            return self.cfg["l1_coeff"] * self.step_counter / (0.05 * self.total_steps)
        else:
            return self.cfg["l1_coeff"]

    def step(self):
        acts = self.buffer.next()
        losses = self.crosscoder.get_losses(acts)
        loss = losses.l2_loss + self.get_l1_coeff() * losses.l1_loss
        loss.backward()
        clip_grad_norm_(self.crosscoder.parameters(), max_norm=1.0)
        self.optimizer.step()
        self.scheduler.step()
        self.optimizer.zero_grad()

        loss_dict = {
            "loss": loss.item(),
            "l2_loss": losses.l2_loss.item(),
            "l1_loss": losses.l1_loss.item(),
            "l0_loss": losses.l0_loss.item(),
            "l1_coeff": self.get_l1_coeff(),
            "lr": self.scheduler.get_last_lr()[0],
        }
        self.step_counter += 1
        return loss_dict

    def log(self, loss_dict):
        # wandb.log fails if wandb.init was never called, so only log there when enabled
        if self.use_wandb:
            wandb.log(loss_dict, step=self.step_counter)
        print(loss_dict)

    def save(self):
        self.crosscoder.save()

    def train(self):
        self.step_counter = 0
        try:
            for i in tqdm.trange(self.total_steps):
                loss_dict = self.step()
                if i % self.cfg["log_every"] == 0:
                    self.log(loss_dict)
                if (i + 1) % self.cfg["save_every"] == 0:
                    self.save()
        finally:
            self.save()


# buffer.refresh()
# %%


# %%
def replacement_hook(mlp_post, hook, encoder):
    # Legacy from the 1L SAE code: assumes encoder(...) returns a tuple whose second
    # element is the reconstruction. CrossCoder.forward returns the reconstruction directly.
    mlp_post_reconstr = encoder(mlp_post)[1]
    return mlp_post_reconstr


def mean_ablate_hook(mlp_post, hook):
    mlp_post[:] = mlp_post.mean([0, 1])
    return mlp_post


def zero_ablate_hook(mlp_post, hook):
    mlp_post[:] = 0.0
    return mlp_post


# %%
# Frequency
# Legacy 1L SAE code, kept for reference: relies on cfg["layer"], cfg["act_name"] and
# cfg["act_size"], which no longer exist, hence the NotImplementedError.
@torch.no_grad()
def get_freqs(num_batches=25, local_encoder=None):
    raise NotImplementedError("This is not implemented yet")
    if local_encoder is None:
        local_encoder = encoder
    act_freq_scores = torch.zeros(local_encoder.d_hidden, dtype=torch.float32).to(
        cfg["device"]
    )
    total = 0
    for i in tqdm.trange(num_batches):
        tokens = all_tokens[torch.randperm(len(all_tokens))[: cfg["model_batch_size"]]]

        _, cache = model.run_with_cache(
            tokens, stop_at_layer=cfg["layer"] + 1, names_filter=cfg["act_name"]
        )
        acts = cache[cfg["act_name"]]
        acts = acts.reshape(-1, cfg["act_size"])

        hidden = local_encoder(acts)[2]

        act_freq_scores += (hidden > 0).sum(0)
        total += hidden.shape[0]
    act_freq_scores /= total
    num_dead = (act_freq_scores == 0).float().mean()
    print("Num dead", num_dead)
    return act_freq_scores
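

# %%
# Sketch (not part of the original training code): one way the hard-coded
# Buffer.normalisation_factor values could be re-estimated for a different model
# or dataset. It follows the formula left commented out in scratch.py,
#     ave_norms = buffer.buffer.norm(dim=-1).mean(0) / np.sqrt(model.cfg.d_model)
# i.e. the average L2 norm of each layer's activation vectors divided by
# sqrt(d_model); dividing by the result gives each layer an average norm of
# sqrt(d_model). This is *not* mean centered. The function name is hypothetical.
import math

import torch


def estimate_normalisation_factors(acts: torch.Tensor) -> torch.Tensor:
    """acts: [n_tokens, n_layers, d_model] of raw activations -> factors: [n_layers]"""
    d_model = acts.shape[-1]
    # Average per-token L2 norm for each layer, computed in fp32 (the buffer is bf16)
    ave_norms = acts.float().norm(dim=-1).mean(0)
    return ave_norms / math.sqrt(d_model)


# Possible usage, since Buffer.buffer stores unnormalised activations:
# buffer.normalisation_factor = estimate_normalisation_factors(buffer.buffer)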
--------------------------------------------------------------------------------
/rare_freq_dir.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/neelnanda-io/Crosscoders/3adc7eb23a5a56f12557d2c6c206f0aa688bdbd3/rare_freq_dir.pt
--------------------------------------------------------------------------------
/scratch.py:
--------------------------------------------------------------------------------
# %%
from functools import partial

import pandas as pd
import plotly.express as px

from utils import *

# NOTE: `line` and `to_numpy`, used below, are plotting/array helpers that are not defined or imported in this repo.
print(model.cfg.n_layers)
print(cfg)
scratch_cfg = {
    "l1_coeff": 5,
    "dec_init_norm": 0.01,
    "log_every": 5,
    "seed": 50,
    "num_tokens": int(1e7),
}
cfg.update(scratch_cfg)
train = Trainer(cfg, model, use_wandb=False)
# train.total_steps = 10
# train.train()
# %%
# CrossCoder.load(name, model, path) reads {path}/{name}.pt and {path}/{name}_cfg.json
cc = CrossCoder.load(17, model=train.model, path=SAVE_DIR / "version_3")
buffer = train.buffer
print(buffer.buffer.norm(dim=-1).mean(0))
# %%
# ave_norms = buffer.buffer.norm(dim=-1).mean(0) / np.sqrt(model.cfg.d_model)
ave_norms = torch.tensor(
    [
        1.8281,
        2.0781,
        2.2031,
        2.4062,
        2.5781,
        2.8281,
        3.1562,
        3.6875,
        4.3125,
        5.4062,
        7.8750,
        16.5000,
    ],
    device="cuda:0",
    dtype=torch.float32,
)
# Mean centered normalisation
mean_norms = torch.tensor(
    [1.4248, 1.5720, 1.6795, 1.8498, 2.0202, 2.2450, 2.5181, 2.9152, 3.3975, 4.1135, 5.1676, 8.3306],
    device="cuda:0",
    dtype=torch.float32,
)
acts2 = buffer.next()
acts = acts2 / ave_norms[:, None] * mean_norms[:, None]
print(acts.shape)
print(acts.norm(dim=-1).mean(0))
print((buffer.buffer[:200] / ave_norms[None, :, None]).norm(dim=-1).mean(0))
# %%
means = buffer.buffer.mean(0) / ave_norms[:, None]
means2 = buffer.buffer.mean(0) / mean_norms[:, None]
line(means)
# %%
cc2 = CrossCoder.load(1, model=train.model, path=SAVE_DIR / "version_12")
# %%
torch.set_grad_enabled(False)
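
# %%
# Sketch (not in the original script): a small helper for the per-layer
# fraction of variance unexplained (FVU) that the next cell computes inline,
#     FVU[layer] = E||x - x_hat||^2 / E||x - mean(x)||^2,
# with squared errors summed over d_model and averaged over the batch. Here the
# mean is taken over the batch itself, whereas the inline version uses `means`
# computed from the whole buffer, so the numbers can differ slightly.
# The helper name is hypothetical.
import torch


def fvu_per_layer(x: torch.Tensor, x_hat: torch.Tensor) -> torch.Tensor:
    """x, x_hat: [batch, n_layers, d_model] -> FVU for each layer: [n_layers]"""
    x, x_hat = x.float(), x_hat.float()
    # Reconstruction error per layer
    mse = (x - x_hat).pow(2).sum(-1).mean(0)
    # Total variance per layer around the batch mean
    variance = (x - x.mean(0, keepdim=True)).pow(2).sum(-1).mean(0)
    return mse / variance


# Possible usage: fvu_per_layer(acts2, cc2(acts2.float()))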
recons_acts2 = cc2(acts2.float())
recons_acts = cc(acts.float())
variance = (acts.float() - means.float()).pow(2).sum(-1).mean(0)
variance2 = (acts2.float() - means2.float()).pow(2).sum(-1).mean(0)
mse = (acts.float() - recons_acts).pow(2).sum(-1).mean(0)
mse2 = (acts2.float() - recons_acts2).pow(2).sum(-1).mean(0)
print(mse)
print(mse2)
print(variance)
print(mse.sum() / variance.sum())
line([mse / variance, mse2 / variance], line_labels=["lambda=2", "lambda=5"], title="FVU")


# %%
cc_acts = cc.encode(acts.float())
(cc_acts > 0).float().sum(-1).mean(0)
# %%
model = train.model
batch_size = 128
seq_len = 64
tokens = all_tokens[:batch_size, :seq_len]
with torch.autocast("cuda", torch.bfloat16):
    _, cache = model.run_with_cache(tokens, names_filter=lambda x: x.endswith("resid_post"))
resids = cache.stack_activation("resid_post")
resids = einops.rearrange(resids, "layer batch seq d_model -> (batch seq) layer d_model")

print(resids.shape)
# resids: [(batch seq), layer, d_model], so the per-layer factors broadcast along dim 1
recons_resids = cc(resids / ave_norms[None, :, None]) * ave_norms[None, :, None]
recons_resids2 = cc2(resids / mean_norms[None, :, None]) * mean_norms[None, :, None]
def replace_resids_hook(resid, hook, layer, lambd=2):
    # flat_resid = resid.reshape(-1, model.cfg.d_model)
    # if lambd == 2:
    #     rec = recons_resids[:, layer]
    # else:
    #     rec = recons_resids2[:, layer]
    # diff = flat_resid - rec
    # print(layer, lambd)
    # print(diff.pow(2).sum(-1).mean(0) / flat_resid.pow(2).sum(-1).mean(0))
    # print(diff.pow(2).sum(-1).mean(0), flat_resid.pow(2).sum(-1).mean(0), recons_resids.pow(2).sum(-1).mean(0))

    # Keep the BOS position and splice in the crosscoder reconstruction everywhere else
    new_resid = torch.zeros_like(resid)
    new_resid[:, 0, :] = resid[:, 0, :]
    if lambd == 2:
        new_resid[:, 1:, :] = recons_resids[:, layer].reshape(batch_size, seq_len, model.cfg.d_model)[:, 1:, :]
    else:
        new_resid[:, 1:, :] = recons_resids2[:, layer].reshape(batch_size, seq_len, model.cfg.d_model)[:, 1:, :]
    return new_resid
losses = []
orig_loss = model(tokens, return_type="loss")
print(orig_loss)
for lambd in [2, 5]:
    for layer in tqdm.trange(model.cfg.n_layers):
        loss = model.run_with_hooks(tokens, fwd_hooks=[(f"blocks.{layer}.hook_resid_post", partial(replace_resids_hook, layer=layer, lambd=lambd))], return_type="loss")
        losses.append({"loss": loss.item(), "lambd": lambd, "layer": layer, "ce_delta": (loss - orig_loss).item()})
loss_df = pd.DataFrame(losses)
px.line(loss_df, x="layer", y="ce_delta", color="lambd", title="CE Delta for each layer").show()
print(loss_df.groupby("lambd")["ce_delta"].mean())


# line(losses)
# %%
import sae_lens
import yaml
# with open("/workspace/SAELens/sae_lens/pretrained_saes.yaml", "r") as file:
#     pretrained_saes = yaml.safe_load(file)
# print(pretrained_saes.keys())
RELEASE = "gpt2-small-resid-post-v5-32k"
saes = []
for i in range(model.cfg.n_layers):
    sae = sae_lens.SAE.from_pretrained(
        release=RELEASE,
        sae_id=f"blocks.{i}.hook_resid_post",
        device="cuda",
    )[0]
    saes.append(sae)
# saes.append(
#     sae_lens.SAE.from_pretrained(
#         release=RELEASE,
#         sae_id=f"blocks.11.hook_resid_post",
#         device="cuda",
#     )
# )
# %%
def LN(
    x: torch.Tensor, eps: float = 1e-5
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    mu = x.mean(dim=-1, keepdim=True)
    x = x - mu
    std = x.std(dim=-1, keepdim=True)
    x = x / (std + eps)
    return x, mu, std


def replace_resids_sae_hook(resid, hook, layer):
    normed_resid, mu, std = LN(resid[:, 1:, :])
    recons_resid = saes[layer](normed_resid) * std + mu
    resid[:, 1:, :] = recons_resid
    return resid


# losses = []
# orig_loss = model(tokens, return_type="loss")
# print(orig_loss)

for layer in tqdm.trange(model.cfg.n_layers):
    loss = model.run_with_hooks(
        tokens,
        fwd_hooks=[
            (
                f"blocks.{layer}.hook_resid_post",
                partial(replace_resids_sae_hook, layer=layer),
            )
        ],
        return_type="loss",
    )
    losses.append(
        {"loss": loss.item(), "lambd": -1, "layer": layer, "ce_delta": (loss - orig_loss).item()}
    )
loss_df = pd.DataFrame(losses)
px.line(
    loss_df, x="layer", y="ce_delta", color="lambd", title="CE Delta for each layer"
).show()
print(loss_df.groupby("lambd")["ce_delta"].mean())
# %%
l0s = []
for i in range(model.cfg.n_layers):
    acts = saes[i].encode(LN(cache["resid_post", i][:, 1:, :])[0])
    l0s.append((acts > 0).sum(-1).float().mean())
fig = line(l0s, return_fig=True, title="L0s for SAEs vs Crosscoders")
filt_resids = resids.reshape(batch_size, seq_len, model.cfg.n_layers, model.cfg.d_model)[:, 1:, :, :]
filt_resids = einops.rearrange(filt_resids, "batch seq layer d_model -> (batch seq) layer d_model")
# Normalise the same way as for the reconstructions above before encoding
lambda2l0 = (cc.encode(filt_resids / ave_norms[None, :, None]) > 0).float().sum(-1).mean()
lambda5l0 = (cc2.encode(filt_resids / mean_norms[None, :, None]) > 0).float().sum(-1).mean()
fig.add_hline(y=lambda2l0, line_dash="dash", line_color="red", annotation_text="Crosscoder lambda=2")
fig.add_hline(y=lambda5l0, line_dash="dash", line_color="green", annotation_text="Crosscoder lambda=5")
fig.show()
# width = 65
# layer = 18
# %%
loss_df["L0"] = [lambda2l0.item()] * 12 + [lambda5l0.item()] * 12 + [i.item() for i in l0s]
px.scatter(loss_df, x="L0", y="ce_delta", color="lambd", title="CE Delta vs L0", color_continuous_scale="Portland")

# %%
acts = buffer.next()
recons_acts = cc(acts)
recons_acts2 = cc2(acts)
recons_acts_sae = torch.zeros_like(acts)
for i in range(model.cfg.n_layers):
    normed_acts, mu, std = LN(acts[:, i, :])
    recons_acts_sae[:, i, :] = (saes[i](normed_acts)) * std + mu
means = acts.mean(0)
variance = (acts - means).pow(2).sum(-1).mean(0)
mse = (acts - recons_acts).pow(2).sum(-1).mean(0)
mse2 = (acts - recons_acts2).pow(2).sum(-1).mean(0)
mse_sae = (acts - recons_acts_sae).pow(2).sum(-1).mean(0)
print(mse)
print(mse2)
print(mse_sae)
print(variance)
print(mse.sum() / variance.sum())
line(
    [mse / variance, mse2 / variance, mse_sae / variance],
    line_labels=["lambda=2", "lambda=5", "SAE"],
    title="FVU"
)
loss_df["mse"] = to_numpy(torch.cat([mse, mse2, mse_sae]))
loss_df["fvu"] = to_numpy(torch.cat([mse/variance, mse2/variance, mse_sae/variance]))
# %%
loss_df["is_sae"] = loss_df["lambd"] == -1
# Create a mapping for lambda values to marker symbols
lambda_symbols = {2: 'circle', 5: 'square', -1: 'diamond'}

# Create the scatter plot
fig = px.scatter(
    loss_df,
    x="L0",
    y="ce_delta",
    facet_col="layer",
    facet_col_wrap=3,
    symbol="lambd",
    symbol_map=lambda_symbols,
    title="CE Delta vs L0",
    color="is_sae",
    labels={"L0": "L0", "ce_delta": "CE Delta", "layer": "Layer", "lambd": "Lambda"},
    color_continuous_scale="Portland",
    height=1000,
)
fig.show()
# %%
x = all_tokens[:256, :64]
_, cache = model.run_with_cache(x, names_filter=lambda x: x.endswith("resid_post"), return_type=None)
# layer, batch, seq, d_model
resids = cache.stack_activation("resid_post")[:, :, 1:, :]
recons_acts_sae = torch.zeros_like(resids)
for i in range(model.cfg.n_layers):
    normed_resids, mu, std = LN(resids[i, :, :])
    recons_acts_sae[i, :, :] = (saes[i](normed_resids)) * std + mu
normed_resids_cc = resids / ave_norms[:, None, None, None]
normed_resids_cc = einops.rearrange(normed_resids_cc, "layer
batch seq d_model -> (batch seq) layer d_model") 256 | recons_acts_cc = cc(normed_resids_cc) 257 | recons_acts_cc = einops.rearrange(recons_acts_cc * ave_norms[None, :, None], "(batch seq) layer d_model -> layer batch seq d_model", batch=x.shape[0]) 258 | recons_acts_cc2 = cc2(normed_resids_cc) 259 | recons_acts_cc2 = einops.rearrange(recons_acts_cc2 * ave_norms[None, :, None], "(batch seq) layer d_model -> layer batch seq d_model", batch=x.shape[0]) 260 | mean_resids = resids.mean([1, 2], keepdim=True) 261 | variance = ((resids - mean_resids) ** 2).sum(-1, keepdim=True).mean([1, 2], keepdim=True) 262 | fvu_sae = ((resids - recons_acts_sae).pow(2)/variance).sum(-1).mean(1) 263 | fvu_cc = ((resids - recons_acts_cc).pow(2)/variance).sum(-1).mean(1) 264 | fvu_cc2 = ((resids - recons_acts_cc2).pow(2)/variance).sum(-1).mean(1) 265 | line(fvu_sae, title="SAE") 266 | line(fvu_cc, title="lambda=2") 267 | line(fvu_cc2, title="lambda=5") 268 | # %% 269 | y = torch.stack([fvu_sae, fvu_cc, fvu_cc2]) 270 | line(y, line_labels=["SAE", "lambda=2", "lambda=5"]*20, title="FVU", facet_col=1, facet_col_wrap=3) 271 | 272 | line(y.mean(-1), line_labels=["SAE", "lambda=2", "lambda=5"]*20, title="FVU") 273 | 274 | # %% 275 | all_sae_acts = [] 276 | all_sae_acts2 = [] 277 | for _ in tqdm.trange(50): 278 | acts2 = buffer.next() 279 | acts = acts2 / ave_norms[:, None] * mean_norms[:, None] 280 | all_sae_acts.append(cc.encode(acts)) 281 | all_sae_acts2.append(cc2.encode(acts2)) 282 | all_sae_acts = torch.cat(all_sae_acts) 283 | all_sae_acts2 = torch.cat(all_sae_acts2) 284 | print(all_sae_acts.shape) 285 | print(all_sae_acts2.shape) 286 | freqs = (all_sae_acts > 0).float().mean(0) 287 | freqs2 = (all_sae_acts2 > 0).float().mean(0) 288 | histogram((freqs+1e-6).log10(), title="SAE") 289 | histogram((freqs2+1e-6).log10(), title="SAE2") 290 | # %% 291 | -------------------------------------------------------------------------------- /scratch_2.py: 
--------------------------------------------------------------------------------
# %%
import json

import torch
from datasets import load_dataset
from utils import *
from transformer_lens import *

sd = torch.load("/workspace/Crosscoders/checkpoints/version_2/17.pt")
model = HookedTransformer.from_pretrained("gpt2-small")

# Hard-coded per-layer normalisation factors, one per GPT-2 Small layer.
normalisation_factor = torch.tensor(
    [
        1.8281,
        2.0781,
        2.2031,
        2.4062,
        2.5781,
        2.8281,
        3.1562,
        3.6875,
        4.3125,
        5.4062,
        7.8750,
        16.5000,
    ],
    device="cuda:0",
    dtype=torch.float32,
)
# %%


def fold_scale_into_sd(sd, normalisation_factor):
    # Absorb the per-layer input normalisation into the weights, so the returned
    # state dict can be applied directly to raw residual-stream activations.
    d = {}
    print(sd.keys())
    d["W_enc"] = sd["W_enc"] / normalisation_factor[:, None, None]  # [n_layers, d_model, d_sae]
    d["W_dec"] = sd["W_dec"] * normalisation_factor[None, :, None]  # [d_sae, n_layers, d_model]
    d["b_enc"] = sd["b_enc"]
    d["b_dec"] = sd["b_dec"] * normalisation_factor[:, None]  # [n_layers, d_model]
    return d


d = fold_scale_into_sd(sd, normalisation_factor)
for k in sd.keys():
    print(k, sd[k].shape, d[k].shape)

# %%
with open("/workspace/Crosscoders/checkpoints/version_2/17_cfg.json", "r") as f:
    cfg = json.load(f)
cc = CrossCoder(cfg, model)
cc.load_state_dict(sd)  # original weights: expect normalised inputs
cc_normed = CrossCoder(cfg, model)
cc_normed.load_state_dict(d)  # folded weights: expect raw inputs
# %%
data = load_dataset("stas/openwebtext-10k")
# %%
s = data["train"][0]["text"]
tokens = model.to_tokens(s)
print(tokens.shape)
# %%
resids = get_stacked_resids(model, tokens, drop_bos=True)
resids.shape
# %%
resids = resids.squeeze(0)
recons = cc(resids)
recons_normed = cc_normed(resids)
print(resids.norm(), (resids - recons).norm(), (resids - recons_normed).norm())
# %%
with open("/workspace/Crosscoders/checkpoints/version_12/1_cfg.json", "r") as f:
    cfg = json.load(f)
sd = torch.load("/workspace/Crosscoders/checkpoints/version_12/1.pt")
sd_normed = fold_scale_into_sd(sd, normalisation_factor)
name = "lambda_2_64k"
# Use a context manager so the config file is flushed and closed before it is reloaded.
with open(f"/workspace/Crosscoders/checkpoints/{name}_cfg.json", "w") as f:
    json.dump(cfg, f)
torch.save(sd_normed, f"/workspace/Crosscoders/checkpoints/{name}.pt")
# %%
cc_new = CrossCoder.load(name, model)
print(cc_new.get_losses(resids))
print(cc.get_losses(resids))
print(cc_normed.get_losses(resids))
# %%
import huggingface_hub

upload_folder_to_hf("/workspace/Crosscoders/checkpoints/", "crosscoders-gpt2-small")

# %%

--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
from setuptools import setup, find_packages

with open("README.md", "r", encoding="utf-8") as fh:
    long_description = fh.read()

setup(
    name="crosscoders",
    version="0.1.0",
    author="Neel Nanda",
    author_email="neel@neelnanda.io",
    description="A package for training GPT-2 Small acausal crosscoders",
    long_description=long_description,
    long_description_content_type="text/markdown",
    url="https://github.com/neelnanda-io/CrossCoders",
    packages=find_packages(),
    classifiers=[
        "Development Status :: 3 - Alpha",
        "Intended Audience :: Science/Research",
        "License :: OSI Approved :: MIT License",
        "Programming Language :: Python :: 3",
        "Programming Language :: Python :: 3.8",
        "Programming Language :: Python :: 3.9",
        "Programming Language :: Python :: 3.10",
    ],
    python_requires=">=3.8",
    install_requires=[
        "torch>=1.12.0",
        "transformer_lens",
        "wandb",
        "einops",
        "tqdm",
        "numpy",
        "huggingface_hub",
    ],
)
--------------------------------------------------------------------------------