├── .gitignore ├── run.sh ├── README.md ├── test_sae.py ├── test_llama.py └── sae.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | activation_cache 3 | wandb 4 | tmp 5 | -------------------------------------------------------------------------------- /run.sh: -------------------------------------------------------------------------------- 1 | python sae.py \ 2 | --model-name meta-llama/Llama-2-7b-hf \ 3 | --layer-idx 8 \ 4 | --cache-dir activation_cache \ 5 | --learning-rate 1e-4 \ 6 | --batch-size 2048 \ 7 | --wandb-project test_sae_project \ 8 | --num-train-samples 1000 \ 9 | --num-val-samples 100 \ 10 | --wandb-run-name test_sae_run_relu \ 11 | --sae-type topk \ 12 | --l1-coef 0.01 \ 13 | --num-epochs 10 14 | 15 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Minimal SAE Trainer 2 | 3 | A principled implementation for training Sparse Autoencoders (SAEs) designed to extract interpretable features from Large Language Model (LLM) activations. This repository prioritizes implementation clarity and computational efficiency. 4 | 5 | ## Overview 6 | 7 | - Efficient activation caching with sharded storage 8 | - Rigorous validation-set evaluation of reconstruction loss 9 | - W&B logging 10 | - FineWeb dataset 11 | 12 | The implementation is particularly suited for researchers studying mechanistic interpretability who require a robust yet maintainable codebase. 13 | 14 | ## Dependencies 15 | 16 | ```bash 17 | pip install torch transformers wandb safetensors datasets click tqdm 18 | ``` 19 | 20 | ## Usage 21 | 22 | ### Basic Training 23 | Execute a basic training run: 24 | 25 | ```bash 26 | export HF_HUB_ENABLE_HF_TRANSFER=1 # Faster FineWeb download (requires the hf_transfer package). 27 | 28 | python sae.py \ 29 | --model-name meta-llama/Llama-2-7b-hf \ 30 | --layer-idx 8 \ 31 | --cache-dir activation_cache \ 32 | --d-hidden 2048 \ 33 | --learning-rate 1e-3 \ 34 | --batch-size 2 \ 35 | --wandb-project test_sae_project \ 36 | --num-train-samples 10000 \ 37 | --num-val-samples 100 \ 38 | --overwrite-cache 39 | ``` 40 | 41 | ### Choosing the Sparse Autoencoder Type 42 | 43 | You can choose between two types of sparse autoencoders: ReLU and TopK. Use the `--sae-type` option to specify which one to use: 44 | 45 | - **ReLU Autoencoder**: This is the default option. It uses a ReLU activation function and an L1 regularization term to encourage sparsity. 46 | 47 | ```bash 48 | python sae.py --sae-type relu 49 | ``` 50 | 51 | - **TopK Autoencoder**: This autoencoder keeps only the K largest activations and sets the rest to zero, so sparsity is enforced directly rather than via an L1 penalty; K is set with the `--topk` option.
52 | 53 | ```bash 54 | python sae.py --sae-type topk --topk 100 55 | ``` 56 | 57 | ## Citation 58 | 59 | ```bibtex 60 | @software{minSAE, 61 | author = {Simo Ryu}, 62 | title = {Minimal SAE Trainer}, 63 | year = {2024}, 64 | publisher = {GitHub}, 65 | url = {https://github.com/cloneofsimo/minSAE} 66 | } 67 | ``` -------------------------------------------------------------------------------- /test_sae.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | from safetensors.torch import save_file 5 | 6 | from sae import ActivationLoader, SparseReLUAutoencoder 7 | 8 | 9 | def test_sae(): 10 | """Test basic SAE functionality""" 11 | print("Testing SAE...") 12 | 13 | # Initialize SAE 14 | d_input = 768 15 | d_hidden = 1024 16 | batch_size = 32 17 | sae = SparseReLUAutoencoder(d_input, d_hidden) 18 | 19 | # Test forward pass dimensions; the SAE forward returns a dict of outputs and losses 20 | batch = torch.randn(batch_size, d_input) 21 | output = sae(batch); decoded, encoded = output["decoded"], output["encoded"] 22 | assert decoded.shape == ( 23 | batch_size, 24 | d_input, 25 | ), f"Wrong decode shape: {decoded.shape}" 26 | assert encoded.shape == ( 27 | batch_size, 28 | d_hidden, 29 | ), f"Wrong encode shape: {encoded.shape}" 30 | 31 | # Test sparsity 32 | assert torch.any(encoded == 0), "No zero activations found" 33 | sparsity = (encoded == 0).float().mean() 34 | assert sparsity > 0.1, f"Not sparse enough: {sparsity}" 35 | 36 | # Test gradients 37 | sae.zero_grad() 38 | loss = torch.nn.functional.mse_loss(decoded, batch) 39 | loss.backward() 40 | 41 | # Check gradients exist and are non-zero 42 | for name, param in sae.named_parameters(): 43 | assert param.grad is not None, f"No gradient for {name}" 44 | assert not torch.all(param.grad == 0), f"Zero gradient for {name}" 45 | 46 | print("SAE tests passed!") 47 | 48 | 49 | def test_data_loading(): 50 | """Test activation loading functionality""" 51 | print("Testing data loading...") 52 | 53 | # Create temporary directory 54 | tmp_dir = "./tmp" 55 | 56 | # Create dummy data 57 | d_input = 768 58 | n_samples = 256 59 | data = torch.randn(n_samples, d_input) 60 | 61 | # Save dummy shard 62 | os.makedirs(os.path.join(tmp_dir, "train"), exist_ok=True) 63 | save_file( 64 | {"activations": data}, os.path.join(tmp_dir, "train", "shard_00000.safetensors") 65 | ) 66 | 67 | # Save metadata 68 | with open(os.path.join(tmp_dir, "train", "metadata.json"), "w") as f: 69 | f.write('{"activation_dim": 768}') 70 | 71 | # Test loading without shuffle 72 | loader = ActivationLoader(tmp_dir, "train", batch_size=32, shuffle=False) 73 | loaded_data = torch.cat([batch for batch in loader]) 74 | assert torch.allclose(loaded_data, data), "Data loading mismatch" 75 | 76 | # Test loading with shuffle 77 | loader = ActivationLoader(tmp_dir, "train", batch_size=32, shuffle=True) 78 | shuffled_data = torch.cat([batch for batch in loader]) 79 | assert not torch.allclose(shuffled_data, data), "Data wasn't shuffled" 80 | 81 | # But should contain same values 82 | assert torch.allclose( 83 | torch.sort(shuffled_data.flatten())[0], torch.sort(data.flatten())[0] 84 | ), "Shuffled data values don't match" 85 | 86 | print("Data loading tests passed!") 87 | 88 | 89 | def test_training(): 90 | """Test basic training loop""" 91 | print("Testing training loop...") 92 | 93 | # Create dummy data and model 94 | d_input = 768 95 | d_hidden = 1024 96 | n_samples = 1000 97 | data = torch.randn(n_samples, d_input) 98 | sae = SparseReLUAutoencoder(d_input, d_hidden) 99 | 100 | # Get initial loss 101 | with torch.no_grad(): 102 | 
decoded = sae(data)["decoded"] 103 | initial_loss = torch.nn.functional.mse_loss(decoded, data).item() 104 | 105 | # Train for a few steps 106 | optimizer = torch.optim.Adam(sae.parameters()) 107 | for _ in range(10): 108 | optimizer.zero_grad() 109 | output = sae(data) 110 | loss = torch.nn.functional.mse_loss(output["decoded"], data) 111 | loss.backward() 112 | optimizer.step() 113 | 114 | # Check final loss 115 | with torch.no_grad(): 116 | decoded = sae(data)["decoded"] 117 | final_loss = torch.nn.functional.mse_loss(decoded, data).item() 118 | 119 | assert ( 120 | final_loss < initial_loss 121 | ), f"Loss didn't improve: {initial_loss} -> {final_loss}" 122 | print("Training tests passed!") 123 | 124 | 125 | if __name__ == "__main__": 126 | print("\nRunning SAE tests...") 127 | test_sae() 128 | test_data_loading() 129 | test_training() 130 | print("\nAll tests passed! 🎉\n") 131 | -------------------------------------------------------------------------------- /test_llama.py: -------------------------------------------------------------------------------- 1 | import json 2 | import shutil 3 | import tempfile 4 | from pathlib import Path 5 | 6 | import torch 7 | from safetensors.torch import load_file, save_file 8 | from tqdm import tqdm 9 | from transformers import AutoModelForCausalLM, AutoTokenizer 10 | 11 | 12 | def test_llama_hooks(): 13 | """Test that we can properly hook LLama activations""" 14 | print("Testing LLama hooks...") 15 | 16 | # Load small model for testing 17 | model_name = "meta-llama/Llama-2-7b-hf" # Change to a smaller model for testing 18 | model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16) 19 | tokenizer = AutoTokenizer.from_pretrained(model_name) 20 | 21 | # Track activations 22 | activations = [] 23 | 24 | def hook_fn(module, input, output): 25 | # Store MLP output activations 26 | activations.append(output.detach().cpu()) 27 | 28 | # Register hook on MLP output 29 | layer_idx = 8 30 | hook_found = False 31 | for name, module in model.named_modules(): 32 | if f"layers.{layer_idx}.mlp.down_proj" in name: 33 | module.register_forward_hook(hook_fn) 34 | hook_found = True 35 | break 36 | 37 | assert hook_found, f"Could not find layer {layer_idx} MLP" 38 | 39 | # Run a forward pass 40 | input_text = "The quick brown fox jumps over the lazy dog." 
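    # The hook registered above fires during this forward pass, appending the down_proj output (shape: batch x seq_len x hidden_size) to `activations`.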
41 | inputs = tokenizer(input_text, return_tensors="pt") 42 | with torch.no_grad(): 43 | model(**inputs) 44 | 45 | print(f"Activations: {activations}") 46 | # Check activations were collected 47 | assert len(activations) > 0, "No activations collected" 48 | assert activations[0].ndim == 3, f"Wrong activation shape: {activations[0].shape}" 49 | print("LLama hook test passed!") 50 | 51 | 52 | def test_activation_caching(): 53 | """Test activation caching functionality with real LLama activations""" 54 | print("Testing activation caching...") 55 | 56 | # Create temporary cache directory 57 | tmp_dir = tempfile.mkdtemp() 58 | try: 59 | # Load model and tokenizer 60 | model_name = "meta-llama/Llama-2-7b-hf" 61 | model = AutoModelForCausalLM.from_pretrained( 62 | model_name, torch_dtype=torch.float16 63 | ) 64 | tokenizer = AutoTokenizer.from_pretrained(model_name) 65 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 66 | model.to(device) 67 | 68 | # Sample texts 69 | texts = [ 70 | "The quick brown fox jumps over the lazy dog.", 71 | "A journey of a thousand miles begins with a single step.", 72 | "To be or not to be, that is the question.", 73 | ] 74 | 75 | # Track activations 76 | activations = [] 77 | 78 | def hook_fn(module, input, output): 79 | activations.append(output.detach().cpu()) 80 | 81 | # Register hook on the MLP output 82 | layer_idx = 8 83 | for name, module in model.named_modules(): 84 | if f"layers.{layer_idx}.mlp.down_proj" in name: 85 | module.register_forward_hook(hook_fn) 86 | break 87 | 88 | # Generate activations 89 | with torch.no_grad(): 90 | for text in tqdm(texts, desc="Collecting activations"): 91 | inputs = tokenizer(text, return_tensors="pt").to(device) 92 | model(**inputs) 93 | 94 | # Save activations in shards 95 | shard_dir = Path(tmp_dir) / "train" 96 | shard_dir.mkdir(parents=True) 97 | 98 | for i, act in enumerate(activations): 99 | # Save shard 100 | shard_path = shard_dir / f"shard_{i:05d}.safetensors" 101 | save_file({"activations": act}, str(shard_path)) 102 | 103 | # Save metadata 104 | metadata = { 105 | "model_name": model_name, 106 | "layer_idx": layer_idx, 107 | "activation_dim": activations[0].shape[-1], 108 | "num_shards": len(activations), 109 | "context_size": max(act.shape[1] for act in activations), 110 | } 111 | with open(shard_dir / "metadata.json", "w") as f: 112 | json.dump(metadata, f) 113 | 114 | # Verify saved files 115 | assert (shard_dir / "metadata.json").exists(), "Metadata not saved" 116 | assert len(list(shard_dir.glob("shard_*.safetensors"))) == len( 117 | activations 118 | ), "Wrong number of shards" 119 | 120 | # Load and verify a shard 121 | shard_path = shard_dir / "shard_00000.safetensors" 122 | loaded_act = load_file(str(shard_path))["activations"] 123 | assert torch.allclose( 124 | loaded_act, activations[0] 125 | ), "Loaded activation doesn't match" 126 | 127 | # Verify activation properties 128 | d_model = activations[0].shape[-1] 129 | for act in activations: 130 | # Check dimensions 131 | assert act.ndim == 3, f"Wrong activation dimensions: {act.ndim}" 132 | assert ( 133 | act.shape[-1] == d_model 134 | ), f"Inconsistent hidden dimension: {act.shape[-1]} vs {d_model}" 135 | 136 | # Check numerical properties 137 | assert not torch.isnan(act).any(), "NaN in activations" 138 | assert not torch.isinf(act).any(), "Inf in activations" 139 | assert torch.isfinite(act).all(), "Non-finite values in activations" 140 | 141 | print( 142 | f"Generated {len(activations)} activation shards of shape 
{activations[0].shape}" 143 | ) 144 | 145 | finally: 146 | # Cleanup 147 | for path in Path(tmp_dir).rglob("*"): 148 | if path.is_file(): 149 | path.unlink() 150 | 151 | shutil.rmtree(tmp_dir) 152 | 153 | print("Activation caching test passed!") 154 | 155 | 156 | if __name__ == "__main__": 157 | print("\nRunning LLama and activation tests...") 158 | test_llama_hooks() 159 | test_activation_caching() 160 | print("\nAll tests passed! 🎉\n") 161 | -------------------------------------------------------------------------------- /sae.py: -------------------------------------------------------------------------------- 1 | import json 2 | from pathlib import Path 3 | from typing import Optional, Tuple, Dict 4 | 5 | import click 6 | import numpy as np 7 | import torch 8 | import torch.nn as nn 9 | from datasets import load_dataset 10 | from safetensors.torch import load_file, save_file 11 | from tqdm.auto import tqdm 12 | from transformers import AutoModelForCausalLM, AutoTokenizer 13 | 14 | import wandb 15 | 16 | 17 | def cache_activations( 18 | model_name: str, 19 | layer_idx: int, 20 | output_dir: str, 21 | context_length: int = 2048, 22 | num_train_samples: Optional[int] = None, 23 | num_val_samples: Optional[int] = None, 24 | shard_size: int = 10000, 25 | batch_size: int = 8, # Added batch size parameter 26 | device: str = "cuda" if torch.cuda.is_available() else "cpu", 27 | dtype: torch.dtype = torch.float16, 28 | split: str = "train", 29 | ) -> None: 30 | """Cache activations from a specific layer of the LLama model.""" 31 | 32 | print(f"\nCaching activations from {model_name} layer {layer_idx}") 33 | print(f"Saving to: {output_dir}") 34 | 35 | # Load model and tokenizer 36 | print("Loading model and tokenizer...") 37 | model = AutoModelForCausalLM.from_pretrained( 38 | model_name, torch_dtype=dtype, device_map={"": device} 39 | ) 40 | model.eval() 41 | 42 | tokenizer = AutoTokenizer.from_pretrained(model_name) 43 | tokenizer.pad_token = tokenizer.eos_token 44 | 45 | d_model = model.config.hidden_size 46 | 47 | # Load FineWeb dataset 48 | dataset = load_dataset( 49 | "HuggingFaceFW/fineweb", name="sample-10BT", split="train", streaming=False 50 | ) 51 | 52 | if split == "train" and num_train_samples: 53 | dataset = dataset.take(num_train_samples) 54 | elif split == "validation" and num_val_samples: 55 | dataset = dataset.skip(num_train_samples).take(num_val_samples) 56 | 57 | # Create output directory 58 | output_dir = Path(output_dir) 59 | output_dir.mkdir(parents=True, exist_ok=True) 60 | 61 | # Storage for current shard 62 | current_activations = [] 63 | shard_idx = 0 64 | samples_processed = 0 65 | total_tokens_processed = 0 66 | 67 | def save_shard(acts, idx: int, split: str = "train") -> None: 68 | """Save current activations as a shard.""" 69 | if not acts: 70 | return 71 | 72 | shard_dir = output_dir / split 73 | shard_dir.mkdir(exist_ok=True) 74 | shard_path = shard_dir / f"shard_{idx:05d}.safetensors" 75 | 76 | # Stack and save 77 | acts_tensor = torch.cat(acts, dim=0) 78 | save_file({"activations": acts_tensor}, str(shard_path)) 79 | return [] 80 | 81 | def hook_fn(module, input, output): 82 | """Hook to capture activations.""" 83 | act = output.detach().cpu() 84 | 85 | # Validate activation 86 | assert not torch.isnan(act).any(), "NaN in activation" 87 | assert not torch.isinf(act).any(), "Inf in activation" 88 | assert torch.isfinite(act).all(), "Non-finite values in activation" 89 | assert ( 90 | act.shape[-1] == d_model 91 | ), f"Wrong activation dimension: {act.shape[-1]} vs 
{d_model}" 92 | 93 | current_activations.extend(list(act)) 94 | 95 | # Register hook on the MLP output 96 | hook_registered = False 97 | for name, module in model.named_modules(): 98 | if f"layers.{layer_idx}.mlp.down_proj" in name: 99 | module.register_forward_hook(hook_fn) 100 | hook_registered = True 101 | break 102 | 103 | assert hook_registered, f"Could not find layer {layer_idx}" 104 | 105 | print("\nCollecting activations...") 106 | try: 107 | with torch.no_grad(): 108 | # Process data in batches 109 | for i in tqdm(range(0, len(dataset), batch_size)): 110 | batch_samples = dataset[i : i + batch_size] 111 | 112 | # Tokenize batch 113 | inputs = tokenizer( 114 | batch_samples["text"], 115 | max_length=context_length, 116 | truncation=True, 117 | padding="max_length", 118 | return_tensors="pt", 119 | ) 120 | inputs = {k: v.to(device) for k, v in inputs.items()} 121 | 122 | # Forward pass 123 | model(**inputs) 124 | 125 | # Update counters 126 | samples_processed += len(batch_samples["text"]) 127 | total_tokens_processed += inputs["input_ids"].numel() 128 | 129 | # Save shard if enough activations 130 | if len(current_activations) >= shard_size: 131 | current_activations = save_shard( 132 | current_activations, shard_idx, split 133 | ) 134 | shard_idx += 1 135 | 136 | # Optional early stopping 137 | if ( 138 | split == "train" 139 | and num_train_samples 140 | and samples_processed >= num_train_samples 141 | ): 142 | break 143 | elif ( 144 | split == "validation" 145 | and num_val_samples 146 | and samples_processed >= num_val_samples 147 | ): 148 | break 149 | 150 | except KeyboardInterrupt: 151 | print("\nInterrupted! Saving final shard...") 152 | 153 | finally: 154 | # Save final shard if any activations remain 155 | if current_activations: 156 | save_shard(current_activations, shard_idx, split) 157 | 158 | # Save metadata 159 | metadata = { 160 | "model_name": model_name, 161 | "layer_idx": layer_idx, 162 | "corpus": "HuggingFaceFW/fineweb", 163 | "activation_dim": d_model, 164 | "num_shards": shard_idx + 1, 165 | "samples_processed": samples_processed, 166 | "total_tokens_processed": total_tokens_processed, 167 | "context_length": context_length, 168 | "shard_size": shard_size, 169 | } 170 | 171 | with open(output_dir / split / "metadata.json", "w") as f: 172 | json.dump(metadata, f, indent=2) 173 | 174 | print(f"\nCaching complete!") 175 | print(f"Saved {shard_idx + 1} shards") 176 | print(f"Processed {samples_processed:,} samples") 177 | print(f"Total tokens: {total_tokens_processed:,}") 178 | 179 | 180 | class ActivationLoader: 181 | """Activation data loader that loads N shards at a time and shuffles between them.""" 182 | 183 | def __init__( 184 | self, 185 | cache_dir: str, 186 | split: str, 187 | batch_size: int, 188 | shuffle: bool = True, 189 | num_shards_in_memory: int = 4, 190 | ): 191 | self.cache_dir = Path(cache_dir) / split 192 | self.batch_size = batch_size 193 | self.shuffle = shuffle 194 | self.num_shards_in_memory = num_shards_in_memory 195 | 196 | # Load metadata 197 | with open(self.cache_dir / "metadata.json") as f: 198 | self.metadata = json.load(f) 199 | 200 | self.shard_paths = sorted(self.cache_dir.glob("shard_*.safetensors")) 201 | if self.shuffle: 202 | np.random.shuffle(self.shard_paths) 203 | 204 | self.current_shards = [] 205 | self.current_indices = None 206 | self.next_shard_idx = 0 207 | 208 | def load_shard(self, shard_path: Path) -> torch.Tensor: 209 | """Load a single shard of activations.""" 210 | data = load_file(str(shard_path)) 211 | return 
data["activations"] 212 | 213 | def load_next_shards(self): 214 | """Load next N shards into memory.""" 215 | self.current_shards = [] 216 | if self.next_shard_idx >= len(self.shard_paths): 217 | return False 218 | 219 | end_idx = min( 220 | self.next_shard_idx + self.num_shards_in_memory, len(self.shard_paths) 221 | ) 222 | 223 | for shard_path in self.shard_paths[self.next_shard_idx : end_idx]: 224 | self.current_shards.append(self.load_shard(shard_path)) 225 | 226 | self.next_shard_idx = end_idx 227 | 228 | # Concatenate shards and create shuffled indices 229 | self.current_data = torch.cat(self.current_shards, dim=0) 230 | if self.shuffle: 231 | self.current_indices = torch.randperm(len(self.current_data)) 232 | else: 233 | self.current_indices = None 234 | 235 | print("Current data", self.current_data.shape) 236 | 237 | return len(self.current_data) > 0 238 | 239 | def __iter__(self): 240 | 241 | self.next_shard_idx = 0 242 | 243 | # Load first batch of shards 244 | has_data = self.load_next_shards() 245 | 246 | while has_data: 247 | # Yield batches from current shards 248 | for start_idx in range(0, len(self.current_data), self.batch_size): 249 | end_idx = min(start_idx + self.batch_size, len(self.current_data)) 250 | 251 | if self.current_indices is not None: 252 | indices = self.current_indices[start_idx:end_idx] 253 | batch = self.current_data[indices] 254 | else: 255 | batch = self.current_data[start_idx:end_idx] 256 | 257 | yield batch 258 | 259 | # Load next batch of shards when current ones are exhausted 260 | has_data = self.load_next_shards() 261 | 262 | 263 | class SparseReLUAutoencoder(nn.Module): 264 | """Sparse ReLU autoencoder. This is typically considered a baseline SAE.""" 265 | 266 | def __init__(self, d_input: int, d_hidden: int, l1_coef: float = 1e-3): 267 | super().__init__() 268 | self.encoder = nn.Linear(d_input, d_hidden) 269 | self.decoder = nn.Linear(d_hidden, d_input) 270 | self.activation = nn.ReLU() 271 | self.l1_coef = l1_coef 272 | self.mse_loss = nn.MSELoss() 273 | 274 | def encode(self, x: torch.Tensor) -> torch.Tensor: 275 | return self.activation(self.encoder(x)) 276 | 277 | def decode(self, h: torch.Tensor) -> torch.Tensor: 278 | return self.decoder(h) 279 | 280 | def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]: 281 | h = self.encode(x) 282 | decoded = self.decode(h) 283 | 284 | # Calculate losses 285 | recon_loss = self.mse_loss(decoded.float(), x.float()) 286 | l1_loss = h.abs().mean() 287 | total_loss = recon_loss + l1_loss * self.l1_coef 288 | 289 | with torch.no_grad(): 290 | l0_loss = (h != 0).float().mean() 291 | 292 | return { 293 | "decoded": decoded, 294 | "encoded": h, 295 | "loss": total_loss, 296 | "recon_loss": recon_loss, 297 | "l1_loss": l1_loss, 298 | "l0_loss": l0_loss, 299 | } 300 | 301 | 302 | class SparseTopKAutoEncoder(nn.Module): 303 | """Sparse TopK autoencoder from Scaling and evaluating sparse autoencoders. 
https://arxiv.org/abs/2406.04093""" 304 | 305 | def __init__(self, d_input: int, d_hidden: int, topk: int = 10): 306 | super().__init__() 307 | self.encoder = nn.Linear(d_input, d_hidden) 308 | self.decoder = nn.Linear(d_hidden, d_input) 309 | self.activation = nn.ReLU() 310 | self.topk = topk 311 | self.mse_loss = nn.MSELoss() 312 | 313 | def encode(self, x: torch.Tensor) -> torch.Tensor: 314 | h = self.activation(self.encoder(x)) 315 | topk_ind = h.topk(k=self.topk, dim=-1, sorted=False).indices 316 | # return topk as is, rest as zero 317 | topk_mask = torch.zeros_like(h) 318 | topk_mask.scatter_(dim=-1, index=topk_ind, value=1) 319 | return h * topk_mask 320 | 321 | def decode(self, h: torch.Tensor) -> torch.Tensor: 322 | return self.decoder(h) 323 | 324 | def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]: 325 | h = self.encode(x) 326 | decoded = self.decode(h) 327 | recon_loss = self.mse_loss(decoded.float(), x.float()) 328 | l1_loss = h.abs().mean() 329 | 330 | with torch.no_grad(): 331 | l0_loss = (h != 0).float().mean() 332 | 333 | return { 334 | "decoded": decoded, 335 | "encoded": h, 336 | "loss": recon_loss, 337 | "recon_loss": recon_loss, 338 | "l1_loss": l1_loss, 339 | "l0_loss": l0_loss, 340 | } 341 | 342 | 343 | @click.command() 344 | @click.option("--model-name", default="meta-llama/Llama-2-7b-hf", help="Model name") 345 | @click.option("--layer-idx", default=8, help="Layer index") 346 | @click.option("--cache-dir", default="activation_cache", help="Cache directory") 347 | @click.option("--d-hidden", default=None, type=int, help="Hidden dimension") 348 | @click.option( 349 | "--expansion-factor", default=4, help="Expansion factor if d_hidden not specified" 350 | ) 351 | @click.option("--learning-rate", default=1e-3, help="Learning rate") 352 | @click.option("--batch-size", default=1024, help="Batch size") 353 | @click.option("--shard-size", default=10000, help="Activations per shard") 354 | @click.option("--l1-coef", default=1e-3, help="L1 loss coefficient") 355 | @click.option("--topk", default=100, help="TopK for TopK SAE") 356 | @click.option("--num-epochs", default=10, help="Number of epochs") 357 | @click.option( 358 | "--num-train-samples", default=None, type=int, help="Number of training samples" 359 | ) 360 | @click.option( 361 | "--num-val-samples", default=None, type=int, help="Number of validation samples" 362 | ) 363 | @click.option("--wandb-project", default="sae-training", help="W&B project name") 364 | @click.option("--wandb-run-name", default=None, help="W&B run name") 365 | @click.option("--overwrite-cache", is_flag=True, help="Overwrite cache") 366 | @click.option("--sae-type", default="relu", help="SAE type") 367 | def train_sae( 368 | model_name: str, 369 | layer_idx: int, 370 | cache_dir: str, 371 | d_hidden: Optional[int], 372 | expansion_factor: int, 373 | learning_rate: float, 374 | batch_size: int, 375 | shard_size: int, 376 | l1_coef: float, 377 | topk: int, 378 | num_epochs: int, 379 | num_train_samples: Optional[int], 380 | num_val_samples: Optional[int], 381 | wandb_project: str, 382 | wandb_run_name: str, 383 | overwrite_cache: bool = False, 384 | sae_type: str = "relu", 385 | ) -> None: 386 | """Train a sparse autoencoder on LLama activations with validation.""" 387 | 388 | # Initialize wandb 389 | config = { 390 | "model_name": model_name, 391 | "layer_idx": layer_idx, 392 | "d_hidden": d_hidden, 393 | "expansion_factor": expansion_factor, 394 | "learning_rate": learning_rate, 395 | "batch_size": batch_size, 396 | "shard_size": 
shard_size, 397 | "l1_coef": l1_coef, 398 | "topk": topk, 399 | "num_epochs": num_epochs, 400 | "num_train_samples": num_train_samples, 401 | "num_val_samples": num_val_samples, 402 | "sae_type": sae_type, 403 | } 404 | wandb.init(project=wandb_project, name=wandb_run_name, config=config) 405 | 406 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 407 | cache_dir = Path(cache_dir) 408 | 409 | # Cache activations if not already cached 410 | if not (cache_dir / "train").exists() or overwrite_cache: 411 | cache_activations( 412 | model_name=model_name, 413 | layer_idx=layer_idx, 414 | output_dir=cache_dir, 415 | context_length=2048, 416 | shard_size=shard_size, 417 | num_train_samples=num_train_samples, 418 | num_val_samples=None, 419 | device=device, 420 | dtype=torch.bfloat16, 421 | split="train", 422 | ) 423 | 424 | if not (cache_dir / "validation").exists() or overwrite_cache: 425 | cache_activations( 426 | model_name=model_name, 427 | layer_idx=layer_idx, 428 | output_dir=cache_dir, 429 | context_length=2048, 430 | shard_size=shard_size, 431 | num_train_samples=num_train_samples, 432 | num_val_samples=num_val_samples, 433 | device=device, 434 | dtype=torch.bfloat16, 435 | split="validation", 436 | ) 437 | 438 | # Load metadata 439 | with open(cache_dir / "train" / "metadata.json") as f: 440 | metadata = json.load(f) 441 | 442 | # Initialize SAE 443 | d_input = metadata["activation_dim"] 444 | if d_hidden is None: 445 | d_hidden = d_input * expansion_factor 446 | 447 | print(f"Initializing SAE with input dim {d_input} and hidden dim {d_hidden}") 448 | 449 | if sae_type == "relu": 450 | sae = SparseReLUAutoencoder(d_input, d_hidden, l1_coef=l1_coef).to(device) 451 | elif sae_type == "topk": 452 | sae = SparseTopKAutoEncoder(d_input, d_hidden, topk=topk).to(device) 453 | 454 | # Initialize decoder bias to zero and decoder weights as a copy of the encoder transpose 455 | sae.decoder.bias.data.zero_() 456 | sae.decoder.weight.data = sae.encoder.weight.data.T.contiguous() # copy, not a view, so encoder/decoder weights are not aliased 457 | 458 | # Create data loaders 459 | train_loader = ActivationLoader( 460 | cache_dir, "train", batch_size=batch_size, shuffle=True 461 | ) 462 | val_loader = ActivationLoader( 463 | cache_dir, "validation", batch_size=batch_size, shuffle=False 464 | ) 465 | 466 | # Training setup 467 | optimizer = torch.optim.Adam(sae.parameters(), lr=learning_rate) 468 | 469 | ctx = torch.amp.autocast(device_type=device.type, dtype=torch.bfloat16) 470 | 471 | def evaluate(): 472 | sae.eval() 473 | total_recon_loss = 0.0 474 | total_l1_loss = 0.0 475 | total_l0_loss = 0.0 476 | total_loss = 0.0 477 | num_batches = 0 478 | 479 | with torch.no_grad(): 480 | for batch in val_loader: 481 | with ctx: 482 | batch = batch.to(device, dtype=torch.bfloat16) 483 | output = sae(batch) 484 | 485 | total_recon_loss += output["recon_loss"].item() 486 | total_l1_loss += output["l1_loss"].item() 487 | total_l0_loss += output["l0_loss"].item() 488 | total_loss += output["loss"].item() 489 | num_batches += 1 490 | 491 | return { 492 | "val_loss": total_loss / num_batches, 493 | "val_recon_loss": total_recon_loss / num_batches, 494 | "val_l1_loss": total_l1_loss / num_batches, 495 | "val_l0_loss": total_l0_loss / num_batches, 496 | } 497 | 498 | # Training loop 499 | print("Starting training...") 500 | best_val_loss = float("inf") 501 | 502 | for epoch in range(num_epochs): 503 | sae.train() 504 | total_loss = 0.0 505 | recon_loss = 0.0 506 | l1_loss = 0.0 507 | batch_count = 0 508 | 509 | for batch in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}"): 510 | optimizer.zero_grad() 511 | 
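            # Forward pass runs under bf16 autocast; the SAE modules upcast to fp32 for the MSE reconstruction loss, and backward/optimizer.step() run outside the autocast context.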
512 | with ctx: 513 | batch = batch.to(device, dtype=torch.bfloat16) 514 | output = sae(batch) 515 | 516 | output["loss"].backward() 517 | optimizer.step() 518 | 519 | total_loss += output["loss"].item() 520 | recon_loss += output["recon_loss"].item() 521 | l1_loss += output["l1_loss"].item() 522 | batch_count += 1 523 | 524 | # Log training metrics periodically 525 | if batch_count % 100 == 1: 526 | metrics = { 527 | "train_loss": output["loss"].item(), 528 | "train_recon_loss": output["recon_loss"].item(), 529 | "train_l1_loss": output["l1_loss"].item(), 530 | "train_l0_loss": output["l0_loss"].item(), 531 | "epoch": epoch, 532 | "batch": batch_count, 533 | } 534 | wandb.log(metrics) 535 | 536 | # Evaluate on validation set 537 | val_metrics = evaluate() 538 | val_metrics["epoch"] = epoch 539 | wandb.log(val_metrics) 540 | 541 | # Save best model 542 | if val_metrics["val_loss"] < best_val_loss: 543 | best_val_loss = val_metrics["val_loss"] 544 | save_path = cache_dir / "best_sae.pt" 545 | torch.save( 546 | { 547 | "epoch": epoch, 548 | "model_state_dict": sae.state_dict(), 549 | "optimizer_state_dict": optimizer.state_dict(), 550 | "validation_loss": best_val_loss, 551 | "config": config, 552 | }, 553 | save_path, 554 | ) 555 | 556 | # Log best model metrics 557 | wandb.run.summary["best_val_loss"] = best_val_loss 558 | wandb.run.summary["best_model_epoch"] = epoch 559 | 560 | # Print epoch statistics 561 | avg_train_loss = total_loss / batch_count 562 | print(f"Epoch {epoch+1}") 563 | print(f"Train - Loss: {avg_train_loss:.4f}") 564 | print(f"Val - Loss: {val_metrics['val_loss']:.4f}") 565 | print(f"Best val loss: {best_val_loss:.4f}") 566 | 567 | # Save final model 568 | final_save_path = cache_dir / "final_sae.pt" 569 | torch.save( 570 | { 571 | "epoch": num_epochs, 572 | "model_state_dict": sae.state_dict(), 573 | "optimizer_state_dict": optimizer.state_dict(), 574 | "validation_loss": val_metrics["val_loss"], 575 | "config": config, 576 | }, 577 | final_save_path, 578 | ) 579 | 580 | print(f"Training complete!") 581 | print(f"Best model saved to {save_path}") 582 | print(f"Final model saved to {final_save_path}") 583 | 584 | # Close wandb run 585 | wandb.finish() 586 | 587 | 588 | if __name__ == "__main__": 589 | train_sae() 590 | --------------------------------------------------------------------------------