├── .gitignore ├── .gitmodules ├── .vscode └── settings.json ├── README.md ├── auto_interp ├── compare_dictionaries.py ├── generate_scripts.py ├── results.pkl └── sae_feature_evals.py ├── compare_geometry ├── examine_duplicate_features.py ├── find_duplicate_features.py └── paper_plots.py ├── config.py ├── dictionary_learning ├── .DS_Store ├── .gitignore ├── __init__.py ├── buffer.py ├── config.py ├── dictionary.py ├── evaluation.py ├── grad_pursuit.py ├── interp.py ├── kernels.py ├── pretrained_dictionary_downloader.sh ├── requirements.txt ├── trainers │ ├── gdm.py │ ├── jump_relu.py │ ├── moe.py │ ├── standard.py │ ├── standard_new.py │ ├── switch.py │ ├── switch1on.py │ ├── top_k.py │ └── trainer.py ├── training.py └── utils.py ├── other └── lobes.py ├── results ├── 1on │ ├── 1on.ipynb │ ├── l0_deltace.png │ ├── l0_mse.png │ ├── l0_recovered.png │ ├── primary-switch-fast.csv │ └── switch-heavy.csv ├── 1on_lb │ ├── 1on_lb.csv │ ├── alpha_deltace.png │ ├── alpha_lossrec.png │ ├── alpha_mse.png │ └── lb.ipynb ├── efficiency │ ├── efficiency-switch.csv │ ├── efficiency-topk.csv │ ├── efficiency.ipynb │ └── efficiency.png ├── gated_lr │ ├── gated_l0_lossrec.png │ ├── gated_l0_mse.png │ ├── gated_mse_deltace.png │ ├── primary-gated-1e-3.csv │ ├── primary-gated-3e-4.csv │ ├── primary-gated-5e-5.csv │ └── primary-gated.ipynb ├── heaviside_softmax │ ├── experts.ipynb │ ├── experts16h.csv │ ├── experts16s.csv │ ├── experts_l0_deltace.png │ ├── experts_l0_lossrec.png │ ├── experts_l0_mse.png │ └── primary-experts.csv ├── load_balance │ ├── alpha_deltace.png │ ├── alpha_lossrec.png │ ├── alpha_mse.pdf │ ├── alpha_mse.png │ ├── lb-sweep.csv │ └── lb.ipynb └── primary │ ├── big-plot.ipynb │ ├── flopmatch_l0_deltace.png │ ├── flopmatch_l0_lossrec.png │ ├── flopmatch_l0_mse.png │ ├── flopmatch_mse_deltace.png │ ├── l0_deltace.png │ ├── l0_lossrec.png │ ├── l0_mse.png │ ├── mse_deltace.png │ ├── primary-flop-match.ipynb │ ├── primary-gated-clean.csv │ ├── primary-gated.csv │ ├── 
primary-relu-clean.csv │ ├── primary-relu.csv │ ├── primary-switch-fast.csv │ ├── primary-switch-flop.csv │ ├── primary-topk-clean.csv │ ├── primary-topk.csv │ ├── primary.ipynb │ ├── switch_sae_l0_lr.pdf │ ├── switch_sae_l0_mse.pdf │ ├── switch_sae_pareto.pdf │ ├── switch_sae_pareto_flop.pdf │ └── switch_sae_pareto_width.pdf ├── save_activations.py ├── scaling_laws └── attempt0 │ ├── plot.ipynb │ ├── train-dense-topk.py │ ├── train-dense-topk.sh │ ├── train-switch-topk.py │ └── train-switch-topk.sh ├── speed.ipynb ├── table ├── create_table_script.py ├── eval_table.py ├── modified_eval.py ├── run_parallel.sh └── train_switch_table.py ├── train-gated.py ├── train-jump.py ├── train-moe.py ├── train-relu.py ├── train-switch-1on.py ├── train-switch-flop.py ├── train-switch.py └── train-topk.py /.gitignore: -------------------------------------------------------------------------------- 1 | dictionaries* 2 | wandb 3 | .DS_Store 4 | weights* 5 | 6 | # Byte-compiled / optimized / DLL files 7 | __pycache__/ 8 | *.py[cod] 9 | *$py.class 10 | 11 | # C extensions 12 | *.so 13 | 14 | # Distribution / packaging 15 | .Python 16 | build/ 17 | develop-eggs/ 18 | dist/ 19 | downloads/ 20 | eggs/ 21 | .eggs/ 22 | lib/ 23 | lib64/ 24 | parts/ 25 | sdist/ 26 | var/ 27 | wheels/ 28 | share/python-wheels/ 29 | *.egg-info/ 30 | .installed.cfg 31 | *.egg 32 | MANIFEST 33 | 34 | # PyInstaller 35 | # Usually these files are written by a python script from a template 36 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
37 | *.manifest 38 | *.spec 39 | 40 | # Installer logs 41 | pip-log.txt 42 | pip-delete-this-directory.txt 43 | 44 | # Unit test / coverage reports 45 | htmlcov/ 46 | .tox/ 47 | .nox/ 48 | .coverage 49 | .coverage.* 50 | .cache 51 | nosetests.xml 52 | coverage.xml 53 | *.cover 54 | *.py,cover 55 | .hypothesis/ 56 | .pytest_cache/ 57 | cover/ 58 | 59 | # Translations 60 | *.mo 61 | *.pot 62 | 63 | # Django stuff: 64 | *.log 65 | local_settings.py 66 | db.sqlite3 67 | db.sqlite3-journal 68 | 69 | # Flask stuff: 70 | instance/ 71 | .webassets-cache 72 | 73 | # Scrapy stuff: 74 | .scrapy 75 | 76 | # Sphinx documentation 77 | docs/_build/ 78 | 79 | # PyBuilder 80 | .pybuilder/ 81 | target/ 82 | 83 | # Jupyter Notebook 84 | .ipynb_checkpoints 85 | 86 | # IPython 87 | profile_default/ 88 | ipython_config.py 89 | 90 | # pyenv 91 | # For a library or package, you might want to ignore these files since the code is 92 | # intended to run in multiple environments; otherwise, check them in: 93 | # .python-version 94 | 95 | # pipenv 96 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 97 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 98 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 99 | # install all needed dependencies. 100 | #Pipfile.lock 101 | 102 | # poetry 103 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 104 | # This is especially recommended for binary packages to ensure reproducibility, and is more 105 | # commonly ignored for libraries. 106 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 107 | #poetry.lock 108 | 109 | # pdm 110 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 
111 | #pdm.lock 112 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 113 | # in version control. 114 | # https://pdm.fming.dev/#use-with-ide 115 | .pdm.toml 116 | 117 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 118 | __pypackages__/ 119 | 120 | # Celery stuff 121 | celerybeat-schedule 122 | celerybeat.pid 123 | 124 | # SageMath parsed files 125 | *.sage.py 126 | 127 | # Environments 128 | .env 129 | .venv 130 | env/ 131 | venv/ 132 | ENV/ 133 | env.bak/ 134 | venv.bak/ 135 | 136 | # Spyder project settings 137 | .spyderproject 138 | .spyproject 139 | 140 | # Rope project settings 141 | .ropeproject 142 | 143 | # mkdocs documentation 144 | /site 145 | 146 | # mypy 147 | .mypy_cache/ 148 | .dmypy.json 149 | dmypy.json 150 | 151 | # Pyre type checker 152 | .pyre/ 153 | 154 | # pytype static type analyzer 155 | .pytype/ 156 | 157 | # Cython debug symbols 158 | cython_debug/ 159 | 160 | # PyCharm 161 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 162 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 163 | # and can be added to the global gitignore or merged into this file. For a more nuclear 164 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
165 | #.idea/ 166 | 167 | # venv 168 | switch 169 | 170 | # plots 171 | plots 172 | *.png 173 | 174 | # data 175 | data 176 | 177 | # PyTorch binaries 178 | *.pt 179 | 180 | # Pickle files 181 | *.pkl 182 | !auto_interp/*.pkl 183 | 184 | # Generated shell files 185 | auto_interp/*.sh 186 | 187 | other_dictionaries 188 | 189 | table/results.csv -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "sae-auto-interp"] 2 | path = sae-auto-interp 3 | url = https://github.com/EleutherAI/sae-auto-interp/ 4 | -------------------------------------------------------------------------------- /.vscode/settings.json: -------------------------------------------------------------------------------- 1 | { 2 | "python.analysis.extraPaths": [ 3 | "./sae-auto-interp" 4 | ] 5 | } -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

2 | Efficient Dictionary Learning with Switch Sparse Autoencoders (SAEs) 3 |

4 | 5 | More soon! 6 | 7 | ## Credits 8 | This repository is adapted from [dictionary_learning](https://github.com/saprmarks/dictionary_learning) by Samuel Marks and Aaron Mueller. 9 | -------------------------------------------------------------------------------- /auto_interp/compare_dictionaries.py: -------------------------------------------------------------------------------- 1 | # %% 2 | 3 | import pickle 4 | import matplotlib.pyplot as plt 5 | import numpy as np 6 | import os 7 | from matplotlib.ticker import FixedLocator, FixedFormatter 8 | 9 | flop_matched_color_dict = { 10 | 1: "#9467bd", 11 | 2: "#ff7f0e", 12 | 4: "#ff6600", 13 | 8: "#ff3300", 14 | } 15 | 16 | fixed_width_color_dict = { 17 | 1: "#9467bd", 18 | 16: "#ff7f0e", 19 | 32: "#ff6600", 20 | 64: "#ff4d00", 21 | 128: "#ff3300", 22 | } 23 | 24 | dictionaries = [ 25 | "dictionaries/topk/k64", 26 | "dictionaries/fixed-width/16_experts/k64", 27 | "dictionaries/fixed-width/32_experts/k64", 28 | "dictionaries/fixed-width/64_experts/k64", 29 | "dictionaries/fixed-width/128_experts/k64", 30 | "dictionaries/flop-matched/2_experts/k64", 31 | "dictionaries/flop-matched/4_experts/k64", 32 | "dictionaries/flop-matched/8_experts/k64", 33 | ] 34 | results = pickle.load(open("results.pkl", "rb")) 35 | 36 | def confidence_interval(successes, total, z=1.96): 37 | p = successes / total 38 | se = np.sqrt(p * (1 - p) / total) 39 | return z * se 40 | 41 | def process_results(results, dictionary): 42 | total_per_quantile = [0 for _ in range(11)] 43 | total_correct_per_quantile = [0 for _ in range(11)] 44 | for quantile_positives, quantile_totals in results[0]: 45 | for i in range(11): 46 | total_per_quantile[i] += quantile_totals[i] 47 | total_correct_per_quantile[i] += quantile_positives[i] 48 | 49 | total_negative = 0 50 | total_negative_correct = 0 51 | for negative_positives, negative_totals in results[1]: 52 | total_negative += negative_totals 53 | total_negative_correct += negative_positives 54 | 55 | 
average_per_quantile = [ 56 | total_correct_per_quantile[i] / total_per_quantile[i] for i in range(1, 11) 57 | ] 58 | average_negative = total_negative_correct / total_negative 59 | 60 | ci_per_quantile = [ 61 | confidence_interval(total_correct_per_quantile[i], total_per_quantile[i]) 62 | for i in range(1, 11) 63 | ] 64 | ci_negative = confidence_interval(total_negative_correct, total_negative) 65 | 66 | x = ["Not"] + ["Q" + str(i) for i in range(1, 11)] 67 | y = [1 - average_negative] + average_per_quantile[::-1] 68 | ci = [ci_negative] + ci_per_quantile[::-1] 69 | 70 | return x, y, ci, dictionary 71 | 72 | # Create two subplots with the new figure size 73 | fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(5.5, 2.5)) 74 | 75 | topk_data = None 76 | 77 | for results, dictionary in zip(results, dictionaries): 78 | x, y, ci, dict_name = process_results(results, dictionary) 79 | 80 | for i in [2, 4, 8, 16, 32, 64, 128]: 81 | if f"/{i}_experts" in dict_name: 82 | label_name = f"Switch: {i}e" 83 | break 84 | if "topk" in dict_name: 85 | label_name = "Topk" 86 | 87 | if "topk" in dict_name: 88 | topk_data = (x, y, ci, dict_name) 89 | color = flop_matched_color_dict[1] # Use color for 1 expert 90 | ax1.errorbar(x, y, yerr=ci, fmt='-o', capsize=5, label=label_name, color=color, markersize=5) 91 | ax2.errorbar(x, y, yerr=ci, fmt='-o', capsize=5, label=label_name, color=color, markersize=5) 92 | elif "flop-matched" in dict_name: 93 | ax = ax1 94 | color_dict = flop_matched_color_dict 95 | expert_count = int(dict_name.split("/")[2].split("_")[0]) 96 | color = color_dict.get(expert_count, "#000000") # Default to black if not found 97 | ax.errorbar(x, y, yerr=ci, fmt='-o', capsize=5, label=label_name, color=color, markersize=5) 98 | else: 99 | ax = ax2 100 | color_dict = fixed_width_color_dict 101 | expert_count = int(dict_name.split("/")[2].split("_")[0]) 102 | color = color_dict.get(expert_count, "#000000") # Default to black if not found 103 | ax.errorbar(x, y, yerr=ci, 
fmt='-o', capsize=5, label=label_name, color=color, markersize=5) 104 | 105 | for ax in (ax1, ax2): 106 | ax.set_xlabel("Quantiles", fontsize=8) 107 | ax.set_ylabel('Accuracy', fontsize=8) 108 | ax.grid(True, which="both", ls="--", linewidth=0.5) 109 | ax.legend(loc='upper center', prop={'size': 5.5}) 110 | 111 | # Set y-axis limits and ticks 112 | ax.set_ylim(0.2, 1.0) # Adjust these values as needed 113 | 114 | # Adjust tick label size 115 | ax.tick_params(axis='both', labelsize=6.5) 116 | 117 | # Remove minor ticks 118 | ax.minorticks_off() 119 | 120 | # Tilt all x-axis labels 121 | for tick in ax.get_xticklabels(): 122 | tick.set_rotation(45) 123 | tick.set_va('top') 124 | 125 | 126 | # Turn off ticks on y axis of ax2 127 | ax2.yaxis.set_ticks_position('none') 128 | ax2.yaxis.set_tick_params(size=0) 129 | ax2.yaxis.set_ticklabels([]) 130 | ax2.set_ylabel('') 131 | 132 | ax1.set_title("FLOP-Matched", fontsize=9) 133 | ax2.set_title("Width-Matched", fontsize=9) 134 | 135 | os.makedirs("plots", exist_ok=True) 136 | plt.savefig("plots/detection_split.pdf", bbox_inches='tight') 137 | plt.show() 138 | # %% 139 | -------------------------------------------------------------------------------- /auto_interp/generate_scripts.py: -------------------------------------------------------------------------------- 1 | # %% 2 | import glob 3 | 4 | import os 5 | 6 | # Change dir to two dirs up 7 | file_path = os.path.abspath(__file__) 8 | os.chdir(os.path.dirname(os.path.dirname(file_path))) 9 | 10 | # Get all paths in dictionaries 11 | dictionaries = glob.glob("dictionaries/**/*.pt", recursive=True) 12 | parent_dirs = set(["/".join(d.split("/")[:-1]) for d in dictionaries]) 13 | 14 | # Change dir back 15 | os.chdir(os.path.dirname(file_path)) 16 | 17 | # Write out run_all_feature_eval_generate.sh 18 | with open("run_all_feature_eval_generate.sh", "w") as f: 19 | for parent_dir in parent_dirs: 20 | f.write(f"python sae_feature_evals.py --sae_path {parent_dir} --to_do 
generate\n") 21 | 22 | # Write out run_all_feature_eval_eval.sh 23 | with open("run_all_feature_eval_eval.sh", "w") as f: 24 | for parent_dir in parent_dirs: 25 | f.write(f"python sae_feature_evals.py --sae_path {parent_dir} --to_do eval\n") 26 | 27 | # Write out run_all_feature_eval_generate_small.sh 28 | # Only k = 64, and (fixed-width or topk) 29 | with open("run_all_feature_eval_generate_small.sh", "w") as f: 30 | for parent_dir in parent_dirs: 31 | if "k64" in parent_dir: 32 | f.write(f"python sae_feature_evals.py --sae_path {parent_dir} --to_do generate\n") 33 | 34 | # Write out run_all_feature_eval_eval_small.sh 35 | # Only k = 64, and (fixed-width or topk) 36 | with open("run_all_feature_eval_eval_small.sh", "w") as f: 37 | for parent_dir in parent_dirs: 38 | if "k64" in parent_dir: 39 | f.write(f"python sae_feature_evals.py --sae_path {parent_dir} --to_do eval\n") -------------------------------------------------------------------------------- /auto_interp/results.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amudide/switch_sae/1b484570c74063b8331cb658a6bbf1964048a7de/auto_interp/results.pkl -------------------------------------------------------------------------------- /auto_interp/sae_feature_evals.py: -------------------------------------------------------------------------------- 1 | # %% 2 | import os 3 | import sys 4 | sys.path.insert(1, os.path.join(sys.path[0], '..')) 5 | 6 | from functools import partial 7 | from nnsight import LanguageModel 8 | import asyncio 9 | 10 | from dictionary_learning.trainers.switch import SwitchAutoEncoder 11 | from dictionary_learning.trainers.top_k import AutoEncoderTopK 12 | 13 | import torch 14 | import einops 15 | import os 16 | 17 | from functools import partial 18 | 19 | import torch 20 | from nnsight import LanguageModel 21 | 22 | from sae_auto_interp.autoencoders.wrapper import AutoencoderLatents 23 | from 
sae_auto_interp.autoencoders.OpenAI import Autoencoder 24 | 25 | 26 | from sae_auto_interp.features import FeatureCache 27 | from sae_auto_interp.features.features import FeatureRecord 28 | from sae_auto_interp.utils import load_tokenized_data 29 | 30 | 31 | from sae_auto_interp.features import FeatureDataset, pool_max_activation_windows, random_activation_windows, sample 32 | from sae_auto_interp.config import FeatureConfig, ExperimentConfig 33 | 34 | 35 | from sae_auto_interp.explainers import SimpleExplainer 36 | from sae_auto_interp.scorers import RecallScorer 37 | from tqdm import tqdm 38 | import pickle 39 | 40 | 41 | from sae_auto_interp.clients import OpenRouter, Local 42 | from sae_auto_interp.utils import display 43 | import argparse 44 | 45 | # %% 46 | 47 | # Change dir to folder one level up from this file 48 | this_dir = os.path.dirname(os.path.abspath(__file__)) 49 | one_level_up = os.path.dirname(this_dir) 50 | os.chdir(one_level_up) 51 | 52 | # %% 53 | 54 | CTX_LEN = 128 55 | BATCH_SIZE = 32 56 | N_TOKENS = 10_000_000 57 | N_SPLITS = 2 58 | NUM_FEATURES_TO_TEST = 1000 59 | 60 | device = "cuda:1" 61 | 62 | # Set torch seed 63 | torch.manual_seed(0) 64 | 65 | # %% 66 | 67 | try: 68 | from IPython import get_ipython # type: ignore 69 | 70 | ipython = get_ipython() 71 | assert ipython is not None 72 | ipython.run_line_magic("load_ext", "autoreload") 73 | ipython.run_line_magic("autoreload", "2") 74 | 75 | is_notebook = True 76 | except: 77 | is_notebook = False 78 | 79 | if not is_notebook: 80 | argparser = argparse.ArgumentParser() 81 | argparser.add_argument("--sae_path", type=str, required=True) 82 | argparser.add_argument("--to_do", type=str, required=True, choices=["generate", "eval", "both"]) 83 | 84 | args = argparser.parse_args() 85 | SAE_PATH = args.sae_path 86 | to_do = args.to_do 87 | else: 88 | SAE_PATH = "dictionaries/topk/k64" 89 | to_do = "both" 90 | # SAE_PATH = "dictionaries/fixed-width/16_experts/k64" 91 | 92 | RAW_ACTIVATIONS_DIR = 
f"/media/jengels/sda/switch/{SAE_PATH}" 93 | SAVE_FILE = f"/media/jengels/sda/switch/{SAE_PATH}/results.pkl" 94 | FINAL_SAVE_FILE = f"/media/jengels/sda/switch/{SAE_PATH}/final_results.pkl" 95 | 96 | # %% 97 | 98 | 99 | if "topk" in SAE_PATH: 100 | k = int(SAE_PATH.split("/")[-1][1:]) 101 | ae = AutoEncoderTopK.from_pretrained(f"{SAE_PATH}/ae.pt", k=k, device=device) 102 | else: 103 | num_experts = int(SAE_PATH.split("/")[-2].split("_")[0]) 104 | k = int(SAE_PATH.split("/")[-1][1:]) 105 | ae = SwitchAutoEncoder.from_pretrained(f"{SAE_PATH}/ae.pt", k=k, experts=num_experts, device=device) 106 | 107 | ae.to(device) 108 | 109 | model = LanguageModel("openai-community/gpt2", device_map=device, dispatch=True) 110 | 111 | # TODO: Ideally use openwebtext 112 | tokens = load_tokenized_data( 113 | CTX_LEN, 114 | model.tokenizer, 115 | "kh4dien/fineweb-100m-sample", 116 | "train[:15%]", 117 | ) 118 | 119 | # %% 120 | 121 | 122 | WIDTH = ae.dict_size 123 | # Get NUM_FEATURES_TO_TEST random features to test without replacement 124 | random_features = torch.randperm(WIDTH)[:NUM_FEATURES_TO_TEST] 125 | 126 | # %% 127 | 128 | generate = to_do in ["generate", "both"] 129 | if generate: 130 | 131 | def _forward(ae, x): 132 | _, _, top_acts, top_indices = ae.forward(x, output_features="all") 133 | 134 | expanded = torch.zeros(top_acts.shape[0], WIDTH, device=device) 135 | expanded.scatter_(1, top_indices, top_acts) 136 | 137 | expanded = einops.rearrange(expanded, "(b c) w -> b c w", b=x.shape[0], c=x.shape[1]) 138 | return expanded 139 | 140 | # We can simply add the new module as an attribute to an existing 141 | # submodule on GPT-2's module tree. 
142 | submodule = model.transformer.h[8] 143 | submodule.ae = AutoencoderLatents( 144 | ae, 145 | partial(_forward, ae), 146 | width=ae.dict_size 147 | ) 148 | 149 | with model.edit(" ", inplace=True): 150 | acts = submodule.output[0] 151 | submodule.ae(acts, hook=True) 152 | 153 | with model.trace("hello, my name is"): 154 | latents = submodule.ae.output.save() 155 | 156 | module_path = submodule.path 157 | 158 | submodule_dict = {module_path : submodule} 159 | module_filter = {module_path : random_features.to(device)} 160 | 161 | cache = FeatureCache( 162 | model, 163 | submodule_dict, 164 | batch_size=BATCH_SIZE, 165 | filters=module_filter 166 | ) 167 | 168 | cache.run(N_TOKENS, tokens) 169 | 170 | cache.save_splits( 171 | n_splits=N_SPLITS, 172 | save_dir=RAW_ACTIVATIONS_DIR, 173 | ) 174 | 175 | # %% 176 | 177 | 178 | cfg = FeatureConfig( 179 | width = WIDTH, 180 | min_examples = 200, 181 | max_examples = 2_000, 182 | example_ctx_len = CTX_LEN, 183 | n_splits = 2, 184 | ) 185 | 186 | sample_cfg = ExperimentConfig(n_random=50) 187 | 188 | # This is a hack because this isn't currently defined in the repo 189 | sample_cfg.chosen_quantile = 0 190 | 191 | dataset = FeatureDataset( 192 | raw_dir=RAW_ACTIVATIONS_DIR, 193 | cfg=cfg, 194 | ) 195 | # %% 196 | 197 | if to_do == "generate": 198 | exit() 199 | 200 | # %% 201 | 202 | # Define these functions here so we don't need to edit the functions in the git submodule 203 | def default_constructor( 204 | record, 205 | tokens, 206 | buffer_output, 207 | n_random: int, 208 | cfg: FeatureConfig 209 | ): 210 | pool_max_activation_windows( 211 | record, 212 | tokens=tokens, 213 | buffer_output=buffer_output, 214 | cfg=cfg 215 | ) 216 | 217 | random_activation_windows( 218 | record, 219 | tokens=tokens, 220 | buffer_output=buffer_output, 221 | n_random=n_random, 222 | ctx_len=cfg.example_ctx_len, 223 | ) 224 | 225 | 226 | constructor=partial( 227 | default_constructor, 228 | n_random=sample_cfg.n_random, 229 | tokens=tokens, 
230 | cfg=cfg 231 | ) 232 | 233 | sampler = partial( 234 | sample, 235 | cfg=sample_cfg 236 | ) 237 | 238 | 239 | def load( 240 | dataset, 241 | constructor, 242 | sampler, 243 | transform = None 244 | ): 245 | def _process(buffer_output): 246 | record = FeatureRecord(buffer_output.feature) 247 | if constructor is not None: 248 | constructor(record=record, buffer_output=buffer_output) 249 | 250 | if sampler is not None: 251 | sampler(record) 252 | 253 | if transform is not None: 254 | transform(record) 255 | 256 | return record 257 | 258 | for buffer in dataset.buffers: 259 | for data in buffer: 260 | if data is not None: 261 | yield _process(data) 262 | # %% 263 | 264 | record_iterator = load(constructor=constructor, sampler=sampler, dataset=dataset, transform=None) 265 | 266 | # next_record = next(record_iterator) 267 | 268 | # display(next_record, model.tokenizer, n=5) 269 | # %% 270 | 271 | # Command to run: vllm serve hugging-quants/Meta-Llama-3.1-70B-Instruct-AWQ-INT4 --max_model_len 10000 --tensor-parallel-size 2 272 | client = Local("hugging-quants/Meta-Llama-3.1-70B-Instruct-AWQ-INT4") 273 | 274 | # %% 275 | 276 | async def run_async(): 277 | 278 | global record_iterator 279 | 280 | 281 | positive_scores = [] 282 | negative_scores = [] 283 | explanations = [] 284 | feature_ids = [] 285 | total_positive_score = 0 286 | total_negative_score = 0 287 | total_evaluated = 0 288 | 289 | bar = tqdm(record_iterator, total=NUM_FEATURES_TO_TEST) 290 | 291 | 292 | 293 | for record in bar: 294 | 295 | explainer = SimpleExplainer( 296 | client, 297 | model.tokenizer, 298 | # max_new_tokens=50, 299 | max_tokens=50, 300 | temperature=0.0 301 | ) 302 | 303 | explainer_result = await explainer(record) 304 | # explainer_result = asyncio.run(explainer(record)) 305 | 306 | # print(explainer_result.explanation) 307 | record.explanation = explainer_result.explanation 308 | 309 | 310 | scorer = RecallScorer( 311 | client, 312 | model.tokenizer, 313 | max_tokens=25, 314 | 
temperature=0.0, 315 | batch_size=4, 316 | ) 317 | 318 | 319 | score = await scorer(record) 320 | 321 | quantile_positives = [0 for _ in range(11)] 322 | quantile_totals = [0 for _ in range(11)] 323 | negative_positives = 0 324 | negative_totals = 0 325 | for score_instance in score.score: 326 | quantile = score_instance.distance 327 | if quantile != -1 and score_instance.prediction != -1: 328 | quantile_totals[quantile] += 1 329 | if score_instance.prediction == 1: 330 | quantile_positives[quantile] += 1 331 | if quantile == -1 and score_instance.prediction != -1: 332 | negative_totals += 1 333 | if score_instance.prediction == 1: 334 | negative_positives += 1 335 | 336 | positive_scores.append((quantile_positives, quantile_totals)) 337 | negative_scores.append((negative_positives, negative_totals)) 338 | 339 | if (sum(quantile_totals) == 0) or (negative_totals == 0): 340 | continue 341 | 342 | total_positive_score += sum(quantile_positives) / sum(quantile_totals) 343 | total_negative_score += negative_positives / negative_totals 344 | total_evaluated += 1 345 | 346 | bar.set_description(f"Positive Recall: {total_positive_score / total_evaluated}, Negative Recall: {total_negative_score / total_evaluated}") 347 | 348 | print(quantile_positives, quantile_totals) 349 | 350 | explanations.append(record.explanation) 351 | 352 | feature_ids.append(record.feature.feature_index) 353 | 354 | 355 | with open(SAVE_FILE, "wb") as f: 356 | pickle.dump((positive_scores, negative_scores, explanations, feature_ids), f) 357 | 358 | with open(FINAL_SAVE_FILE, "wb") as f: 359 | pickle.dump((positive_scores, negative_scores, explanations, feature_ids), f) 360 | 361 | 362 | # Switch comment when running in notebook/command line 363 | # await run_async() 364 | asyncio.run(run_async()) 365 | 366 | # %% 367 | -------------------------------------------------------------------------------- /compare_geometry/examine_duplicate_features.py: 
-------------------------------------------------------------------------------- 1 | # %% 2 | import os 3 | import sys 4 | sys.path.insert(1, os.path.join(sys.path[0], '..')) 5 | 6 | import pandas as pd 7 | import torch 8 | from dictionary_learning.trainers.switch import SwitchAutoEncoder 9 | from dictionary_learning.trainers.top_k import AutoEncoderTopK 10 | import einops 11 | from tqdm import tqdm 12 | from transformers import GPT2Tokenizer 13 | 14 | torch.set_grad_enabled(False) 15 | 16 | # %% 17 | 18 | # device = "cuda:0" 19 | device = "cpu" 20 | duplicate_features = pd.read_csv('../data/duplicates.csv') 21 | data = torch.load("../data/gpt2_activations_layer8.pt", map_location=device) 22 | tokens = torch.load("../data/gpt2_tokens.pt", map_location=device) 23 | save_location = "../data/top_activating_for_dupes_layer8.pt" 24 | ctx_len = 128 25 | 26 | # %% 27 | 28 | batch_size = 256 29 | store_topk_activating = 100 30 | 31 | unique_num_experts = duplicate_features['num_experts'].unique() 32 | unique_topks = duplicate_features['k'].unique() 33 | 34 | ae_details_to_top_activating = {} 35 | 36 | 37 | for num_experts in unique_num_experts: 38 | for k in tqdm(unique_topks): 39 | filtered_df = duplicate_features[(duplicate_features['num_experts'] == num_experts) & (duplicate_features['k'] == k)] 40 | feature_index_to_num_dupes = filtered_df.set_index('feature_index')['num_dupes'].to_dict() 41 | feature_index_tensor = torch.tensor(list(feature_index_to_num_dupes.keys()), device=device) 42 | feature_index_to_topk_activating = {feature_index: [] for feature_index in feature_index_to_num_dupes.keys()} 43 | 44 | if feature_index_tensor.numel() == 0: 45 | ae_details_to_top_activating[(num_experts, k)] = {} 46 | continue 47 | 48 | if num_experts == 1: 49 | ae = AutoEncoderTopK.from_pretrained(f"../dictionaries/topk/k{k}/ae.pt", k=k, device=device) 50 | else: 51 | ae = SwitchAutoEncoder.from_pretrained(f"../dictionaries/fixed-width/{num_experts}_experts/k{k}/ae.pt", k=k, 
experts=num_experts, device=device) 52 | 53 | key = (num_experts, k) 54 | 55 | for batch_start in range(0, len(data), batch_size): 56 | batch = data[batch_start:batch_start+batch_size].to(device) 57 | batch_tokens = tokens[batch_start:batch_start+batch_size].to(device) 58 | 59 | top_acts, top_indices = ae.forward(batch, output_features="all")[2:] 60 | activations = torch.zeros(batch_tokens.shape[0] * batch_tokens.shape[1], len(ae.decoder), device=device) 61 | 62 | activations.scatter_(1, top_indices, top_acts) 63 | dupe_feature_activations = activations[..., feature_index_tensor] 64 | top_activating = torch.topk(dupe_feature_activations, store_topk_activating, dim=0) 65 | for j, feature_index in enumerate(feature_index_tensor): 66 | top_activating_indices = top_activating.indices[:, j] 67 | top_activating_values = top_activating.values[:, j] 68 | top_activating_contexts = top_activating_indices // ctx_len + batch_start 69 | top_activating_token_ids = top_activating_indices % ctx_len 70 | feature_index_to_topk_activating[feature_index.item()].extend(list(zip(top_activating_contexts.cpu().numpy(), top_activating_token_ids.cpu().numpy(), top_activating_values.cpu().numpy()))) 71 | 72 | # Get the top k activations for each feature index 73 | for feature_index, top_activating in feature_index_to_topk_activating.items(): 74 | num_dupes = feature_index_to_num_dupes[feature_index] 75 | top_activating.sort(key=lambda x: x[2], reverse=True) 76 | feature_index_to_topk_activating[feature_index] = (top_activating[:store_topk_activating], num_dupes) 77 | 78 | ae_details_to_top_activating[key] = feature_index_to_topk_activating 79 | 80 | torch.save(ae_details_to_top_activating, save_location) 81 | 82 | # %% 83 | 84 | # Load the top activating for duplicates 85 | ae_details_to_top_activating = torch.load(save_location, map_location=device) 86 | 87 | context_limit = 15 88 | 89 | # Load GPT2 tokenizer 90 | tokenizer = GPT2Tokenizer.from_pretrained("gpt2") 91 | 92 | for (num_experts, 
k), feature_index_to_topk_activating in tqdm(ae_details_to_top_activating.items()): 93 | output_file = f"../data/top_activating_for_dupes_layer8_num-experts{num_experts}_topk{k}.txt" 94 | 95 | if num_experts == 1: 96 | ae = AutoEncoderTopK.from_pretrained(f"../dictionaries/topk/k{k}/ae.pt", k=k, device=device) 97 | else: 98 | ae = SwitchAutoEncoder.from_pretrained(f"../dictionaries/fixed-width/{num_experts}_experts/k{k}/ae.pt", k=k, experts=num_experts, device=device) 99 | 100 | decoder_vecs = ae.decoder 101 | decoder_vecs = decoder_vecs / decoder_vecs.norm(dim=-1, keepdim=True) 102 | sims = decoder_vecs @ decoder_vecs.T 103 | threshold = 0.9 104 | # Zero diagonal of sims 105 | sims.fill_diagonal_(0) 106 | 107 | encoder_vecs = ae.encoder.weight 108 | encoder_vecs = encoder_vecs / encoder_vecs.norm(dim=-1, keepdim=True) 109 | encoder_sims = encoder_vecs @ encoder_vecs.T 110 | 111 | with open(output_file, 'w') as f: 112 | f.write(f"--------------- NUM EXPERTS = {num_experts}, K = {k} ---------------\n\n\n") 113 | for feature_index, (top_activating, num_dupes) in feature_index_to_topk_activating.items(): 114 | similar = torch.nonzero(sims[feature_index] > threshold).flatten().cpu().numpy() 115 | nonzero_sims = sims[feature_index][similar].cpu().numpy() 116 | this_encoder_sims = encoder_sims[feature_index][similar].cpu().numpy() 117 | f.write(f"feature {feature_index} with {num_dupes} dupes: {similar}, {nonzero_sims}, {this_encoder_sims}\n") 118 | for context_index, token_index, value in top_activating: 119 | context_start = max(0, token_index - context_limit) 120 | context_end = min(token_index, ctx_len) 121 | token_context = tokens[context_index][context_start:context_end] 122 | token_context_strs = [tokenizer.decode(token) for token in token_context] 123 | context = "".join(token_context_strs) 124 | context = context.replace("\n", " ") 125 | if context_end + 1 < ctx_len: 126 | next_token = tokenizer.decode(tokens[context_index][context_end: context_end + 1][0]) 127 
| next_token = next_token.replace("\n", "") 128 | else: 129 | next_token = "" 130 | f.write(f"{value:.4f} {next_token}\n{context}\n") 131 | f.write("\n") 132 | f.write("\n") 133 | # %% 134 | -------------------------------------------------------------------------------- /compare_geometry/find_duplicate_features.py: -------------------------------------------------------------------------------- 1 | # %% 2 | 3 | import os 4 | import sys 5 | sys.path.insert(1, os.path.join(sys.path[0], '..')) 6 | 7 | from dictionary_learning.trainers.switch import SwitchAutoEncoder 8 | from dictionary_learning.trainers.top_k import AutoEncoderTopK 9 | import einops 10 | import matplotlib.pyplot as plt 11 | import torch 12 | import numpy as np 13 | from tqdm import tqdm 14 | import os 15 | import pandas as pd 16 | 17 | 18 | os.makedirs("plots/compare_geometry", exist_ok=True) 19 | 20 | torch.set_grad_enabled(False) 21 | 22 | 23 | experts = [1, 16, 32, 64, 128] 24 | ks = [8, 16, 32, 48, 64, 96, 128, 192] 25 | 26 | device = "cuda:1" 27 | # device = "cpu" 28 | 29 | threshold = 0.9 30 | 31 | 32 | intra_sae_max_sims = {} 33 | 34 | fig, axs = plt.subplots(len(experts), len(ks), figsize=(20, 10)) 35 | 36 | data = [] 37 | 38 | for i, num_experts in enumerate(experts): 39 | 40 | for j, k in enumerate(tqdm(ks)): 41 | ax = axs[i, j] 42 | 43 | if num_experts == 1: 44 | ae = AutoEncoderTopK.from_pretrained(f"../dictionaries/topk/k{k}/ae.pt", k=k, device=device) 45 | else: 46 | ae = SwitchAutoEncoder.from_pretrained(f"../dictionaries/fixed-width/{num_experts}_experts/k{k}/ae.pt", k=k, experts=num_experts, device=device) 47 | 48 | normalized_weights = ae.decoder.data / ae.decoder.data.norm(dim=-1, keepdim=True) 49 | 50 | sims = normalized_weights @ normalized_weights.T 51 | 52 | sims.fill_diagonal_(0) 53 | 54 | num_dupes = (sims > threshold).sum(dim=-1).cpu().numpy() 55 | 56 | ax.hist([i for i in num_dupes if i != 0], bins=100) 57 | 58 | ax.set_title(f"{num_experts} experts, k={k}") 59 | 60 | for 
feature_index, num_dupes in enumerate(num_dupes): 61 | if num_dupes == 0: 62 | continue 63 | data.append((num_experts, k, feature_index, num_dupes)) 64 | 65 | fig.suptitle(f"Number of duplicates per feature, threshold={threshold}") 66 | 67 | fig.supxlabel("Number of duplicates") 68 | 69 | fig.supylabel("Number of features") 70 | 71 | plt.tight_layout() 72 | 73 | plt.savefig("plots/compare_geometry/num_duplicates_per_feature_fixed_width.pdf") 74 | 75 | df = pd.DataFrame(data, columns=["num_experts", "k", "feature_index", "num_dupes"]) 76 | 77 | # Sort by num_experts, then k, then num_dupes in descending order 78 | df = df.sort_values(by=["num_experts", "k", "num_dupes"], ascending=[True, True, False]) 79 | 80 | os.makedirs("../data", exist_ok=True) 81 | df.to_csv("../data/duplicates.csv", index=False) 82 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | lm = 'openai-community/gpt2' 2 | activation_dim = 768 3 | layer = 8 4 | hf = 'Skylion007/openwebtext' 5 | steps = 100_000 6 | n_ctxs = int(1e4) 7 | -------------------------------------------------------------------------------- /dictionary_learning/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amudide/switch_sae/1b484570c74063b8331cb658a6bbf1964048a7de/dictionary_learning/.DS_Store -------------------------------------------------------------------------------- /dictionary_learning/.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | 
share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
160 | #.idea/ 161 | dictionaries 162 | wandb 163 | experiment* 164 | run_experiment.sh 165 | nohup.out -------------------------------------------------------------------------------- /dictionary_learning/__init__.py: -------------------------------------------------------------------------------- 1 | from .dictionary import AutoEncoder, GatedAutoEncoder, JumpReluAutoEncoder 2 | from .buffer import ActivationBuffer -------------------------------------------------------------------------------- /dictionary_learning/config.py: -------------------------------------------------------------------------------- 1 | # debugging flag for use in other scripts 2 | DEBUG = False -------------------------------------------------------------------------------- /dictionary_learning/evaluation.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utilities for evaluating dictionaries on a model and dataset. 3 | """ 4 | 5 | import torch as t 6 | from .buffer import ActivationBuffer, NNsightActivationBuffer 7 | from nnsight import LanguageModel 8 | from .config import DEBUG 9 | 10 | 11 | def loss_recovered( 12 | text, # a batch of text 13 | model: LanguageModel, # an nnsight LanguageModel 14 | submodule, # submodules of model 15 | dictionary, # dictionaries for submodules 16 | max_len=None, # max context length for loss recovered 17 | normalize_batch=False, # normalize batch before passing through dictionary 18 | io="out", # can be 'in', 'out', or 'in_and_out' 19 | tracer_args = {'use_cache': False, 'output_attentions': False}, # minimize cache during model trace. 20 | ): 21 | """ 22 | How much of the model's loss is recovered by replacing the component output 23 | with the reconstruction by the autoencoder? 
24 | """ 25 | 26 | if max_len is None: 27 | invoker_args = {} 28 | else: 29 | invoker_args = {"truncation": True, "max_length": max_len } 30 | 31 | # unmodified logits 32 | with model.trace(text, invoker_args=invoker_args): 33 | logits_original = model.output.save() 34 | logits_original = logits_original.value 35 | 36 | # logits when replacing component activations with reconstruction by autoencoder 37 | with model.trace(text, **tracer_args, invoker_args=invoker_args): 38 | if io == 'in': 39 | x = submodule.input[0] 40 | if type(submodule.input.shape) == tuple: x = x[0] 41 | if normalize_batch: 42 | scale = (dictionary.activation_dim ** 0.5) / x.norm(dim=-1).mean() 43 | x = x * scale 44 | elif io == 'out': 45 | x = submodule.output 46 | if type(submodule.output.shape) == tuple: x = x[0] 47 | if normalize_batch: 48 | scale = (dictionary.activation_dim ** 0.5) / x.norm(dim=-1).mean() 49 | x = x * scale 50 | elif io == 'in_and_out': 51 | x = submodule.input[0] 52 | if type(submodule.input.shape) == tuple: x = x[0] 53 | print(f'x.shape: {x.shape}') 54 | if normalize_batch: 55 | scale = (dictionary.activation_dim ** 0.5) / x.norm(dim=-1).mean() 56 | x = x * scale 57 | else: 58 | raise ValueError(f"Invalid value for io: {io}") 59 | x = x.save() 60 | 61 | # pull this out so dictionary can be written without FakeTensor (top_k needs this) 62 | x_hat = dictionary(x.view(-1, x.shape[-1])).view(x.shape) 63 | 64 | # intervene with `x_hat` 65 | with model.trace(text, **tracer_args, invoker_args=invoker_args): 66 | if io == 'in': 67 | x = submodule.input[0] 68 | if normalize_batch: 69 | scale = (dictionary.activation_dim ** 0.5) / x.norm(dim=-1).mean() 70 | x_hat = x_hat / scale 71 | if type(submodule.input.shape) == tuple: 72 | submodule.input[0][:] = x_hat 73 | else: 74 | submodule.input = x_hat 75 | elif io == 'out': 76 | x = submodule.output 77 | if normalize_batch: 78 | scale = (dictionary.activation_dim ** 0.5) / x.norm(dim=-1).mean() 79 | x_hat = x_hat / scale 80 | if 
type(submodule.output.shape) == tuple: 81 | submodule.output = (x_hat,) 82 | else: 83 | submodule.output = x_hat 84 | elif io == 'in_and_out': 85 | x = submodule.input[0] 86 | if normalize_batch: 87 | scale = (dictionary.activation_dim ** 0.5) / x.norm(dim=-1).mean() 88 | x_hat = x_hat / scale 89 | submodule.output = x_hat 90 | else: 91 | raise ValueError(f"Invalid value for io: {io}") 92 | 93 | logits_reconstructed = model.output.save() 94 | logits_reconstructed = logits_reconstructed.value 95 | 96 | # logits when replacing component activations with zeros 97 | with model.trace(text, **tracer_args, invoker_args=invoker_args): 98 | if io == 'in': 99 | x = submodule.input[0] 100 | if type(submodule.input.shape) == tuple: 101 | submodule.input[0][:] = t.zeros_like(x[0]) 102 | else: 103 | submodule.input = t.zeros_like(x) 104 | elif io in ['out', 'in_and_out']: 105 | x = submodule.output 106 | if type(submodule.output.shape) == tuple: 107 | submodule.output[0][:] = t.zeros_like(x[0]) 108 | else: 109 | submodule.output = t.zeros_like(x) 110 | else: 111 | raise ValueError(f"Invalid value for io: {io}") 112 | 113 | input = model.input.save() 114 | logits_zero = model.output.save() 115 | logits_zero = logits_zero.value 116 | 117 | # get everything into the right format 118 | try: 119 | logits_original = logits_original.logits 120 | logits_reconstructed = logits_reconstructed.logits 121 | logits_zero = logits_zero.logits 122 | except: 123 | pass 124 | 125 | if isinstance(text, t.Tensor): 126 | tokens = text 127 | else: 128 | try: 129 | tokens = input[1]['input_ids'] 130 | except: 131 | tokens = input[1]['input'] 132 | 133 | # compute losses 134 | losses = [] 135 | if hasattr(model, 'tokenizer') and model.tokenizer is not None: 136 | loss_kwargs = {'ignore_index': model.tokenizer.pad_token_id} 137 | else: 138 | loss_kwargs = {} 139 | for logits in [logits_original, logits_reconstructed, logits_zero]: 140 | loss = t.nn.CrossEntropyLoss(**loss_kwargs)( 141 | logits[:, :-1, 
:].reshape(-1, logits.shape[-1]), tokens[:, 1:].reshape(-1) 142 | ) 143 | losses.append(loss) 144 | 145 | return tuple(losses) 146 | 147 | 148 | def evaluate( 149 | dictionary, # a dictionary 150 | activations, # a generator of activations; if an ActivationBuffer, also compute loss recovered 151 | max_len=128, # max context length for loss recovered 152 | batch_size=128, # batch size for loss recovered 153 | io="out", # can be 'in', 'out', or 'in_and_out' 154 | normalize_batch=False, # normalize batch before passing through dictionary 155 | tracer_args={'use_cache': False, 'output_attentions': False}, # minimize cache during model trace. 156 | device="cpu", 157 | ): 158 | with t.no_grad(): 159 | 160 | out = {} # dict of results 161 | 162 | try: 163 | x = next(activations).to(device) 164 | if normalize_batch: 165 | x = x / x.norm(dim=-1).mean() * (dictionary.activation_dim ** 0.5) 166 | 167 | except StopIteration: 168 | raise StopIteration( 169 | "Not enough activations in buffer. Pass a buffer with a smaller batch size or more data." 
170 | ) 171 | 172 | x_hat, f = dictionary(x, output_features=True) 173 | l2_loss = t.linalg.norm(x - x_hat, dim=-1).mean() 174 | mse_loss = (x - x_hat).pow(2).sum(dim=-1).mean() 175 | l1_loss = f.norm(p=1, dim=-1).mean() 176 | l0 = (f != 0).float().sum(dim=-1).mean() 177 | frac_alive = t.flatten(f, start_dim=0, end_dim=1).any(dim=0).sum() / dictionary.dict_size 178 | 179 | # cosine similarity between x and x_hat 180 | x_normed = x / t.linalg.norm(x, dim=-1, keepdim=True) 181 | x_hat_normed = x_hat / t.linalg.norm(x_hat, dim=-1, keepdim=True) 182 | cossim = (x_normed * x_hat_normed).sum(dim=-1).mean() 183 | 184 | # l2 ratio 185 | l2_ratio = (t.linalg.norm(x_hat, dim=-1) / t.linalg.norm(x, dim=-1)).mean() 186 | 187 | #compute variance explained 188 | total_variance = t.var(x, dim=0).sum() 189 | residual_variance = t.var(x - x_hat, dim=0).sum() 190 | frac_variance_explained = (1 - residual_variance / total_variance) 191 | 192 | out["l2_loss"] = l2_loss.item() 193 | out["l1_loss"] = l1_loss.item() 194 | out["mse_loss"] = mse_loss.item() 195 | out["l0"] = l0.item() 196 | out["frac_alive"] = frac_alive.item() 197 | out["frac_variance_explained"] = frac_variance_explained.item() 198 | out["cossim"] = cossim.item() 199 | out["l2_ratio"] = l2_ratio.item() 200 | 201 | if not isinstance(activations, (ActivationBuffer, NNsightActivationBuffer)): 202 | return out 203 | 204 | # compute loss recovered 205 | loss_original, loss_reconstructed, loss_zero = loss_recovered( 206 | activations.text_batch(batch_size=batch_size), 207 | activations.model, 208 | activations.submodule, 209 | dictionary, 210 | max_len=max_len, 211 | normalize_batch=normalize_batch, 212 | io=io, 213 | tracer_args=tracer_args 214 | ) 215 | frac_recovered = (loss_reconstructed - loss_zero) / (loss_original - loss_zero) 216 | 217 | out["loss_original"] = loss_original.item() 218 | out["loss_reconstructed"] = loss_reconstructed.item() 219 | out["loss_zero"] = loss_zero.item() 220 | out["frac_recovered"] = 
frac_recovered.item() 221 | 222 | return out -------------------------------------------------------------------------------- /dictionary_learning/grad_pursuit.py: -------------------------------------------------------------------------------- 1 | """ 2 | Implements batched gradient pursuit algorithm here: 3 | https://www.lesswrong.com/posts/C5KAZQib3bzzpeyrg/full-post-progress-update-1-from-the-gdm-mech-interp-team#Inference_Time_Optimisation:~:text=two%20seem%20promising.-,Details%20of%20Sparse%20Approximation%20Algorithms%20(for%20accelerators),-This%20section%20gets 4 | """ 5 | 6 | import torch as t 7 | 8 | 9 | def _grad_pursuit_update_step(signal, weights, dictionary, batch_arange, selected_features): 10 | """ 11 | signal: b x d, weights: b x n, dictionary: d x n, batch_arange: b, selected_features: b x n 12 | """ 13 | residual = signal - t.einsum('bn,dn -> bd', weights, dictionary) 14 | # choose the element with largest inner product with residual, as in matched pursuit. 15 | inner_products = t.einsum('dn,bd -> bn', dictionary, residual) 16 | idxs = t.argmax(inner_products, dim=1) 17 | # add the new feature to the active set. 
18 | selected_features[batch_arange, idxs] = 1 19 | 20 | # the gradient for the weights is the inner product, restricted to the chosen features 21 | grad = selected_features * inner_products 22 | # the next two steps compute the optimal step size 23 | c = t.einsum('bn,dn -> bd', grad, dictionary) 24 | step_size = t.einsum('bd,bd -> b', c, residual) / t.einsum('bd,bd -> b ', c, c) 25 | weights = weights + t.einsum('b,bn -> bn', step_size, grad) 26 | weights = t.clip(weights, min=0) # clip the weights to be positive 27 | return weights, selected_features 28 | 29 | def grad_pursuit(signal, dictionary, target_l0 : int = 20, device : str = 'cpu'): 30 | """ 31 | Inputs: signal: b x d, dictionary: d x n, target_l0: int, device: str 32 | Outputs: weights: b x n 33 | """ 34 | assert len(signal.shape) == 2 # makes sure this a batch of signals 35 | with t.no_grad(): 36 | batch_arange = t.arange(signal.shape[0]).to(device) 37 | weights = t.zeros((signal.shape[0], dictionary.shape[1])).to(device) 38 | selected_features = t.zeros((signal.shape[0], dictionary.shape[1])).to(device) 39 | for _ in range(target_l0): 40 | weights, selected_features = _grad_pursuit_update_step( 41 | signal, weights, dictionary, batch_arange, selected_features) 42 | return weights -------------------------------------------------------------------------------- /dictionary_learning/interp.py: -------------------------------------------------------------------------------- 1 | import random 2 | from circuitsvis.activations import text_neuron_activations 3 | from einops import rearrange 4 | import torch as t 5 | from collections import namedtuple 6 | import umap 7 | import pandas as pd 8 | import plotly.express as px 9 | 10 | TRACER_KWARGS = {"scan" : False, "validate" : False} 11 | 12 | def feature_effect( 13 | model, 14 | submodule, 15 | dictionary, 16 | feature, 17 | inputs, 18 | max_length=128, 19 | add_residual=True, # whether to compensate for dictionary reconstruction error by adding residual 20 | 
k=10, 21 | largest=True, 22 | ): 23 | """ 24 | Effect of ablating the feature on top k predictions for next token. 25 | """ 26 | # clean run 27 | with t.no_grad(), model.trace(inputs, invoker_args=dict(max_length=max_length, truncation=True)): 28 | if dictionary is None: 29 | pass 30 | elif not add_residual: # run hidden state through autoencoder 31 | if type(submodule.output.shape) == tuple: 32 | submodule.output[0][:] = dictionary(submodule.output[0]) 33 | else: 34 | submodule.output = dictionary(submodule.output) 35 | clean_output = model.output.save() 36 | try: 37 | clean_logits = clean_output.value.logits[:, -1, :] 38 | except: 39 | clean_logits = clean_output.value[:, -1, :] 40 | clean_logprobs = t.nn.functional.log_softmax(clean_logits, dim=-1) 41 | 42 | # ablated run 43 | with t.no_grad(), model.trace(inputs, invoker_args=dict(max_length=max_length, truncation=True)): 44 | if dictionary is None: 45 | if type(submodule.output.shape) == tuple: 46 | submodule.output[0][:, -1, feature] = 0 47 | else: 48 | submodule.output[:, -1, feature] = 0 49 | else: 50 | x = submodule.output 51 | if type(x.shape) == tuple: 52 | x = x[0] 53 | x_hat, f = dictionary(x, output_features=True) 54 | residual = x - x_hat 55 | 56 | f[:, -1, feature] = 0 57 | if add_residual: 58 | x_hat = dictionary.decode(f) + residual 59 | else: 60 | x_hat = dictionary.decode(f) 61 | 62 | if type(submodule.output.shape) == tuple: 63 | submodule.output[0][:] = x_hat 64 | else: 65 | submodule.output = x_hat 66 | ablated_output = model.output.save() 67 | try: 68 | ablated_logits = ablated_output.value.logits[:, -1, :] 69 | except: 70 | ablated_logits = ablated_output.value[:, -1, :] 71 | ablated_logprobs = t.nn.functional.log_softmax(ablated_logits, dim=-1) 72 | 73 | diff = clean_logprobs - ablated_logprobs 74 | top_probs, top_tokens = t.topk(diff.mean(dim=0), k=k, largest=largest) 75 | return top_tokens, top_probs 76 | 77 | 78 | def examine_dimension(model, submodule, buffer, dictionary=None, 
max_length=128, n_inputs=512, 79 | dim_idx=None, k=30): 80 | 81 | def _list_decode(x): 82 | if isinstance(x, int): 83 | return model.tokenizer.decode(x) 84 | else: 85 | return [_list_decode(y) for y in x] 86 | 87 | if dim_idx is None: 88 | dim_idx = random.randint(0, activations.shape[-1]-1) 89 | 90 | inputs = buffer.text_batch(batch_size=n_inputs) 91 | with t.no_grad(), model.trace(inputs, invoker_args=dict(max_length=max_length, truncation=True)): 92 | tokens = model.input[1]['input_ids'].save() # if you're getting errors, check here; might only work for pythia models 93 | activations = submodule.output 94 | if type(activations.shape) == tuple: 95 | activations = activations[0] 96 | if dictionary is not None: 97 | activations = dictionary.encode(activations) 98 | activations = activations[:,:, dim_idx].save() 99 | activations = activations.value 100 | 101 | # get top k tokens by mean activation 102 | tokens = tokens.value 103 | token_mean_acts = {} 104 | for ctx in tokens: 105 | for tok in ctx: 106 | if tok.item() in token_mean_acts: 107 | continue 108 | idxs = (tokens == tok).nonzero(as_tuple=True) 109 | token_mean_acts[tok.item()] = activations[idxs].mean().item() 110 | top_tokens = sorted(token_mean_acts.items(), key=lambda x: x[1], reverse=True)[:k] 111 | top_tokens = [(model.tokenizer.decode(tok), act) for tok, act in top_tokens] 112 | 113 | flattened_acts = rearrange(activations, 'b n -> (b n)') 114 | topk_indices = t.argsort(flattened_acts, dim=0, descending=True)[:k] 115 | batch_indices = topk_indices // activations.shape[1] 116 | token_indices = topk_indices % activations.shape[1] 117 | tokens = [ 118 | tokens[batch_idx, :token_idx+1].tolist() for batch_idx, token_idx in zip(batch_indices, token_indices) 119 | ] 120 | activations = [ 121 | activations[batch_idx, :token_id+1, None, None] for batch_idx, token_id in zip(batch_indices, token_indices) 122 | ] 123 | decoded_tokens = _list_decode(tokens) 124 | top_contexts = 
text_neuron_activations(decoded_tokens, activations) 125 | 126 | top_affected = feature_effect( 127 | model, 128 | submodule, 129 | dictionary, 130 | dim_idx, 131 | tokens, 132 | max_length=max_length, 133 | k=k 134 | ) 135 | top_affected = [(model.tokenizer.decode(tok), prob.item()) for tok, prob in zip(*top_affected)] 136 | 137 | return namedtuple('featureProfile', ['top_contexts', 'top_tokens', 'top_affected'])(top_contexts, top_tokens, top_affected) 138 | 139 | def feature_umap( 140 | dictionary, 141 | weight='decoder', # 'encoder' or 'decoder' 142 | # UMAP parameters 143 | n_neighbors=15, 144 | metric='cosine', 145 | min_dist=0.05, 146 | n_components=2, # dimension of the UMAP embedding 147 | feat_idxs=None, # if not none, indicate the feature with a red dot 148 | ): 149 | """ 150 | Fit a UMAP embedding of the dictionary features and return a plotly plot of the result.""" 151 | if weight == 'encoder': 152 | df = pd.DataFrame(dictionary.encoder.weight.cpu().detach().numpy()) 153 | else: 154 | df = pd.DataFrame(dictionary.decoder.weight.T.cpu().detach().numpy()) 155 | reducer = umap.UMAP( 156 | n_neighbors=n_neighbors, 157 | metric=metric, 158 | min_dist=min_dist, 159 | n_components=n_components, 160 | ) 161 | embedding = reducer.fit_transform(df) 162 | if feat_idxs is None: 163 | colors = None 164 | if isinstance(feat_idxs, int): 165 | feat_idxs = [feat_idxs] 166 | else: 167 | colors = ['blue' if i not in feat_idxs else 'red' for i in range(embedding.shape[0])] 168 | if n_components == 2: 169 | return px.scatter(x=embedding[:, 0], y=embedding[:, 1], hover_name=df.index, color=colors) 170 | if n_components == 3: 171 | return px.scatter_3d(x=embedding[:, 0], y=embedding[:, 1], z=embedding[:, 2], hover_name=df.index, color=colors) 172 | raise ValueError("n_components must be 2 or 3") -------------------------------------------------------------------------------- /dictionary_learning/pretrained_dictionary_downloader.sh: 
-------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Base URL for downloads - replace this with your actual URL, excluding the /aX part 4 | BASE_URL="https://baulab.us/u/smarks/autoencoders/pythia-70m-deduped" 5 | 6 | # Local directory where you want to replicate the folder structure, now including the specified root directory 7 | LOCAL_DIR="dictionaries/pythia-70m-deduped" 8 | 9 | # Default 'a' values array 10 | declare -a default_a_values=("attn_out_layerX" "mlp_out_layerX" "resid_out_layerX") # Removed "embed" from default handling 11 | 12 | # 'c' values array - Initially without "checkpoints" 13 | declare -a c_values=("ae.pt" "config.json") 14 | 15 | # Name of the set of autoencoders 16 | sae_set_name="10_32768" 17 | 18 | # Checkpoints flag variable, default is 0 (don't download checkpoints) 19 | download_checkpoints=0 20 | 21 | # Custom layers to download 22 | declare -a custom_layers=("0" "1" "2" "3" "4" "5" "embed") 23 | 24 | # Parse flags 25 | while [[ "$#" -gt 0 ]]; do 26 | case $1 in 27 | -c|--checkpoints) download_checkpoints=1 ;; 28 | --layers) IFS=',' read -ra custom_layers <<< "$2"; shift ;; 29 | *) echo "Unknown parameter passed: $1"; exit 1 ;; 30 | esac 31 | shift 32 | done 33 | 34 | # Ensure the root download directory exists 35 | mkdir -p "${LOCAL_DIR}" 36 | 37 | # Include checkpoints if flag is set 38 | if [[ $download_checkpoints -eq 1 ]]; then 39 | c_values+=("checkpoints") 40 | fi 41 | 42 | # Prepare 'a' values based on custom_layers input; default_a_values are templates whose X is replaced by each layer number 43 | declare -a a_values=() 44 | if [[ ${#custom_layers[@]} -eq 0 ]]; then # only reachable through an empty or missing --layers value 45 | echo "No layers to download: pass e.g. --layers 0,1,embed" >&2; exit 1 46 | else 47 | for layer in "${custom_layers[@]}"; do 48 | if [[ $layer == "embed" ]]; then 49 | a_values+=("embed") 50 | else 51 | for a_value in "${default_a_values[@]}"; do 52 | a_values+=("${a_value/X/$layer}") 53 | done 54 | fi 55 | done 56 | fi 57 | 58 | # Download logic 59 | for a_value in "${a_values[@]}"; do 60 | for c in
"${c_values[@]}"; do 61 | DOWNLOAD_URL="${BASE_URL}/${a_value}/${sae_set_name}/${c}" 62 | LOCAL_PATH="${LOCAL_DIR}/${a_value}/${sae_set_name}/${c}" 63 | if [ "${c}" == "checkpoints" ]; then 64 | # Special handling for downloading checkpoints as folders 65 | mkdir -p "${LOCAL_PATH}" 66 | wget -r -np -nH --cut-dirs=7 -P "${LOCAL_PATH}" --accept "*.pt" "${DOWNLOAD_URL}/" 67 | 68 | else 69 | # Handle all other files 70 | mkdir -p "$(dirname "${LOCAL_PATH}")" 71 | wget -P "$(dirname "${LOCAL_PATH}")" "${DOWNLOAD_URL}" 72 | fi 73 | done 74 | done 75 | 76 | echo "Download completed." 77 | -------------------------------------------------------------------------------- /dictionary_learning/requirements.txt: -------------------------------------------------------------------------------- 1 | circuitsvis>=1.43.2 2 | datasets>=2.18.0 3 | einops>=0.7.0 4 | matplotlib>=3.8.3 5 | nnsight>=0.2.11 6 | pandas>=2.2.1 7 | plotly>=5.18.0 8 | torch>=2.1.2 9 | tqdm>=4.66.1 10 | umap-learn>=0.5.6 11 | zstandard>=0.22.0 12 | wandb 13 | -------------------------------------------------------------------------------- /dictionary_learning/trainers/gdm.py: -------------------------------------------------------------------------------- 1 | """ 2 | Implements the training scheme for a gated SAE described in https://arxiv.org/abs/2404.16014 3 | """ 4 | 5 | import torch as t 6 | from ..trainers.trainer import SAETrainer 7 | from ..config import DEBUG 8 | from ..dictionary import GatedAutoEncoder 9 | from collections import namedtuple 10 | 11 | class ConstrainedAdam(t.optim.Adam): 12 | """ 13 | A variant of Adam where some of the parameters are constrained to have unit norm.
14 | """ 15 | def __init__(self, params, constrained_params, lr): 16 | super().__init__(params, lr=lr, betas=(0, 0.999)) 17 | self.constrained_params = list(constrained_params) 18 | 19 | def step(self, closure=None): 20 | with t.no_grad(): 21 | for p in self.constrained_params: 22 | normed_p = p / p.norm(dim=0, keepdim=True) 23 | # project away the parallel component of the gradient 24 | p.grad -= (p.grad * normed_p).sum(dim=0, keepdim=True) * normed_p 25 | super().step(closure=closure) 26 | with t.no_grad(): 27 | for p in self.constrained_params: 28 | # renormalize the constrained parameters 29 | p /= p.norm(dim=0, keepdim=True) 30 | 31 | class GatedSAETrainer(SAETrainer): 32 | """ 33 | Gated SAE training scheme. 34 | """ 35 | def __init__(self, 36 | dict_class=GatedAutoEncoder, 37 | activation_dim=512, 38 | dict_size=64*512, 39 | lr=5e-5, 40 | l1_penalty=1e-1, 41 | warmup_steps=1000, # lr warmup period at start of training and after each resample 42 | resample_steps=None, # how often to resample neurons 43 | seed=None, 44 | device=None, 45 | layer=None, 46 | lm_name=None, 47 | wandb_name='GatedSAETrainer', 48 | ): 49 | super().__init__(seed) 50 | 51 | assert layer is not None and lm_name is not None 52 | self.layer = layer 53 | self.lm_name = lm_name 54 | 55 | if seed is not None: 56 | t.manual_seed(seed) 57 | t.cuda.manual_seed_all(seed) 58 | 59 | # initialize dictionary 60 | self.ae = dict_class(activation_dim, dict_size) 61 | 62 | self.lr = lr 63 | self.l1_penalty=l1_penalty 64 | self.warmup_steps = warmup_steps 65 | self.wandb_name = wandb_name 66 | 67 | if device is None: 68 | self.device = 'cuda' if t.cuda.is_available() else 'cpu' 69 | else: 70 | self.device = device 71 | self.ae.to(self.device) 72 | 73 | self.optimizer = ConstrainedAdam( 74 | self.ae.parameters(), 75 | self.ae.decoder.parameters(), 76 | lr=lr 77 | ) 78 | def warmup_fn(step): 79 | return min(1, step / warmup_steps) 80 | self.scheduler = t.optim.lr_scheduler.LambdaLR(self.optimizer, 
warmup_fn) 81 | 82 | def loss(self, x, logging=False, **kwargs): 83 | f, f_gate = self.ae.encode(x, return_gate=True) 84 | x_hat = self.ae.decode(f) 85 | x_hat_gate = f_gate @ self.ae.decoder.weight.detach().T + self.ae.decoder_bias.detach() 86 | 87 | L_recon = (x - x_hat).pow(2).sum(dim=-1).mean() 88 | L_sparse = t.linalg.norm(f_gate, ord=1, dim=-1).mean() 89 | L_aux = (x - x_hat_gate).pow(2).sum(dim=-1).mean() 90 | 91 | loss = L_recon + self.l1_penalty * L_sparse + L_aux 92 | 93 | if not logging: 94 | return loss 95 | else: 96 | return namedtuple('LossLog', ['x', 'x_hat', 'f', 'losses'])( 97 | x, x_hat, f, 98 | { 99 | 'mse_loss' : L_recon.item(), 100 | 'sparsity_loss' : L_sparse.item(), 101 | 'aux_loss' : L_aux.item(), 102 | 'loss' : loss.item() 103 | } 104 | ) 105 | 106 | def update(self, step, x): 107 | x = x.to(self.device) 108 | self.optimizer.zero_grad() 109 | loss = self.loss(x) 110 | loss.backward() 111 | self.optimizer.step() 112 | self.scheduler.step() 113 | 114 | @property 115 | def config(self): 116 | return { 117 | 'trainer_class' : 'GatedSAETrainer', 118 | 'activation_dim' : self.ae.activation_dim, 119 | 'dict_size' : self.ae.dict_size, 120 | 'lr' : self.lr, 121 | 'l1_penalty' : self.l1_penalty, 122 | 'warmup_steps' : self.warmup_steps, 123 | 'device' : self.device, 124 | 'layer' : self.layer, 125 | 'lm_name' : self.lm_name, 126 | 'wandb_name': self.wandb_name, 127 | } 128 | -------------------------------------------------------------------------------- /dictionary_learning/trainers/jump_relu.py: -------------------------------------------------------------------------------- 1 | import torch as t 2 | import torch.autograd as autograd 3 | from ..trainers.trainer import SAETrainer 4 | from ..dictionary import JumpReluAutoEncoder, StepFunction 5 | 6 | from ..config import DEBUG 7 | from collections import namedtuple 8 | 9 | class JumpReluTrainer(SAETrainer): 10 | def __init__(self, 11 | dict_class=JumpReluAutoEncoder, 12 | activation_dim=512, 13 | 
dict_size=64*512, 14 | lr=5e-5, 15 | l0_penalty=1e-1, 16 | warmup_steps=1000, # lr warmup period at start of training and after each resample 17 | seed=None, 18 | device=None, 19 | layer=None, 20 | lm_name=None, 21 | wandb_name='JumpReluTrainer', 22 | submodule_name=None, 23 | set_linear_to_constant=False, 24 | pre_encoder_bias=True, 25 | ): 26 | super().__init__(seed) 27 | if seed is not None: 28 | t.manual_seed(seed) 29 | t.cuda.manual_seed_all(seed) 30 | 31 | assert layer is not None and lm_name is not None 32 | self.layer = layer 33 | self.lm_name = lm_name 34 | self.submodule_name = submodule_name 35 | self.wandb_name = wandb_name 36 | 37 | if device is None: 38 | self.device = 'cuda' if t.cuda.is_available() else 'cpu' 39 | else: 40 | self.device = device 41 | 42 | self.ae = dict_class(activation_dim, dict_size, pre_encoder_bias=pre_encoder_bias) 43 | self.ae.to(self.device) 44 | self.lr = lr 45 | self.warmup_steps = warmup_steps 46 | self.l0_penalty = l0_penalty 47 | self.set_linear_to_constant = set_linear_to_constant 48 | 49 | self.optimizer = t.optim.Adam(self.ae.parameters(), betas=(0, 0.999), eps=1e-8) 50 | def warmup_fn(step): 51 | return min(1, step / warmup_steps) 52 | self.scheduler = t.optim.lr_scheduler.LambdaLR(self.optimizer, warmup_fn) 53 | 54 | def loss(self, x, logging=False, **kwargs): 55 | x_hat, f = self.ae(x, output_features=True, set_linear_to_constant=self.set_linear_to_constant) 56 | L_recon = (x - x_hat).pow(2).sum(dim=-1).mean() 57 | L_spars = StepFunction.apply(f, self.ae.threshold, self.ae.bandwidth).sum(dim=-1).mean() 58 | 59 | loss = L_recon + self.l0_penalty * L_spars 60 | 61 | if not logging: 62 | return loss 63 | else: 64 | return namedtuple('LossLog', ['x', 'x_hat', 'f', 'losses'])( 65 | x, x_hat, f, 66 | losses={ 67 | 'mse_loss' : L_recon.item(), 68 | 'sparsity_loss' : L_spars.item(), 69 | 'loss' : loss.item() 70 | } 71 | ) 72 | 73 | def update(self, step, x): 74 | x = x.to(self.device) 75 | self.optimizer.zero_grad() 76 | 
loss = self.loss(x) 77 | loss.backward() 78 | self.optimizer.step() 79 | 80 | @property 81 | def config(self): 82 | return { 83 | 'dict_class': 'JumpReluAutoEncoder', 84 | 'trainer_class': 'JumpReluTrainer', 85 | 'activation_dim': self.ae.activation_dim, 86 | 'dict_size': self.ae.dict_size, 87 | 'lr': self.lr, 88 | 'warmup_steps': self.warmup_steps, 89 | 'l0_penalty': self.l0_penalty, 90 | 'device': self.device, 91 | 'layer': self.layer, 92 | 'lm_name': self.lm_name, 93 | 'submodule_name': self.submodule_name, 94 | 'wandb_name': self.wandb_name, 95 | } -------------------------------------------------------------------------------- /dictionary_learning/trainers/standard.py: -------------------------------------------------------------------------------- 1 | """ 2 | Implements the standard SAE training scheme. 3 | """ 4 | import torch as t 5 | from ..trainers.trainer import SAETrainer 6 | from ..config import DEBUG 7 | from ..dictionary import AutoEncoder 8 | from collections import namedtuple 9 | 10 | class ConstrainedAdam(t.optim.Adam): 11 | """ 12 | A variant of Adam where some of the parameters are constrained to have unit norm. 13 | """ 14 | def __init__(self, params, constrained_params, lr): 15 | super().__init__(params, lr=lr) 16 | self.constrained_params = list(constrained_params) 17 | 18 | def step(self, closure=None): 19 | with t.no_grad(): 20 | for p in self.constrained_params: 21 | normed_p = p / p.norm(dim=0, keepdim=True) 22 | # project away the parallel component of the gradient 23 | p.grad -= (p.grad * normed_p).sum(dim=0, keepdim=True) * normed_p 24 | super().step(closure=closure) 25 | with t.no_grad(): 26 | for p in self.constrained_params: 27 | # renormalize the constrained parameters 28 | p /= p.norm(dim=0, keepdim=True) 29 | 30 | class StandardTrainer(SAETrainer): 31 | """ 32 | Standard SAE training scheme. 
33 | """ 34 | def __init__(self, 35 | dict_class=AutoEncoder, 36 | activation_dim=512, 37 | dict_size=64*512, 38 | lr=1e-3, 39 | l1_penalty=1e-1, 40 | warmup_steps=1000, # lr warmup period at start of training and after each resample 41 | resample_steps=None, # how often to resample neurons 42 | seed=None, 43 | device=None, 44 | layer=None, 45 | lm_name=None, 46 | wandb_name='StandardTrainer', 47 | ): 48 | super().__init__(seed) 49 | 50 | assert layer is not None and lm_name is not None 51 | self.layer = layer 52 | self.lm_name = lm_name 53 | 54 | if seed is not None: 55 | t.manual_seed(seed) 56 | t.cuda.manual_seed_all(seed) 57 | 58 | # initialize dictionary 59 | self.ae = dict_class(activation_dim, dict_size) 60 | 61 | self.lr = lr 62 | self.l1_penalty=l1_penalty 63 | self.warmup_steps = warmup_steps 64 | self.wandb_name = wandb_name 65 | 66 | if device is None: 67 | self.device = 'cuda' if t.cuda.is_available() else 'cpu' 68 | else: 69 | self.device = device 70 | self.ae.to(self.device) 71 | 72 | self.resample_steps = resample_steps 73 | 74 | 75 | if self.resample_steps is not None: 76 | # how many steps since each neuron was last activated? 77 | self.steps_since_active = t.zeros(self.ae.dict_size, dtype=int).to(self.device) 78 | else: 79 | self.steps_since_active = None 80 | 81 | self.optimizer = ConstrainedAdam(self.ae.parameters(), self.ae.decoder.parameters(), lr=lr) 82 | if resample_steps is None: 83 | def warmup_fn(step): 84 | return min(step / warmup_steps, 1.) 85 | else: 86 | def warmup_fn(step): 87 | return min((step % resample_steps) / warmup_steps, 1.) 
88 | self.scheduler = t.optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=warmup_fn) 89 | 90 | def resample_neurons(self, deads, activations): 91 | with t.no_grad(): 92 | if deads.sum() == 0: return 93 | print(f"resampling {deads.sum().item()} neurons") 94 | 95 | # compute loss for each activation 96 | losses = (activations - self.ae(activations)).norm(dim=-1) 97 | 98 | # sample input to create encoder/decoder weights from 99 | n_resample = min([deads.sum(), losses.shape[0]]) 100 | indices = t.multinomial(losses, num_samples=n_resample, replacement=False) 101 | sampled_vecs = activations[indices] 102 | 103 | # get norm of the living neurons 104 | alive_norm = self.ae.encoder.weight[~deads].norm(dim=-1).mean() 105 | 106 | # resample first n_resample dead neurons 107 | deads[deads.nonzero()[n_resample:]] = False 108 | self.ae.encoder.weight[deads] = sampled_vecs * alive_norm * 0.2 109 | self.ae.decoder.weight[:,deads] = (sampled_vecs / sampled_vecs.norm(dim=-1, keepdim=True)).T 110 | self.ae.encoder.bias[deads] = 0. 111 | 112 | 113 | # reset Adam parameters for dead neurons 114 | state_dict = self.optimizer.state_dict()['state'] 115 | ## encoder weight 116 | state_dict[1]['exp_avg'][deads] = 0. 117 | state_dict[1]['exp_avg_sq'][deads] = 0. 118 | ## encoder bias 119 | state_dict[2]['exp_avg'][deads] = 0. 120 | state_dict[2]['exp_avg_sq'][deads] = 0. 121 | ## decoder weight 122 | state_dict[3]['exp_avg'][:,deads] = 0. 123 | state_dict[3]['exp_avg_sq'][:,deads] = 0. 
124 | 125 | def loss(self, x, logging=False, **kwargs): 126 | x_hat, f = self.ae(x, output_features=True) 127 | l2_loss = t.linalg.norm(x - x_hat, dim=-1).mean() 128 | l1_loss = f.norm(p=1, dim=-1).mean() 129 | 130 | if self.steps_since_active is not None: 131 | # update steps_since_active 132 | deads = (f == 0).all(dim=0) 133 | self.steps_since_active[deads] += 1 134 | self.steps_since_active[~deads] = 0 135 | 136 | loss = l2_loss + self.l1_penalty * l1_loss 137 | 138 | if not logging: 139 | return loss 140 | else: 141 | return namedtuple('LossLog', ['x', 'x_hat', 'f', 'losses'])( 142 | x, x_hat, f, 143 | { 144 | 'l2_loss' : l2_loss.item(), 145 | 'mse_loss' : (x - x_hat).pow(2).sum(dim=-1).mean().item(), 146 | 'sparsity_loss' : l1_loss.item(), 147 | 'loss' : loss.item() 148 | } 149 | ) 150 | 151 | 152 | def update(self, step, activations): 153 | activations = activations.to(self.device) 154 | 155 | self.optimizer.zero_grad() 156 | loss = self.loss(activations) 157 | loss.backward() 158 | self.optimizer.step() 159 | self.scheduler.step() 160 | 161 | if self.resample_steps is not None and step % self.resample_steps == 0: 162 | self.resample_neurons(self.steps_since_active > self.resample_steps / 2, activations) 163 | 164 | @property 165 | def config(self): 166 | return { 167 | 'trainer_class' : 'StandardTrainer', 168 | 'lr' : self.lr, 169 | 'l1_penalty' : self.l1_penalty, 170 | 'warmup_steps' : self.warmup_steps, 171 | 'resample_steps' : self.resample_steps, 172 | 'device' : self.device, 173 | 'layer' : self.layer, 174 | 'lm_name' : self.lm_name, 175 | 'wandb_name': self.wandb_name, 176 | } 177 | -------------------------------------------------------------------------------- /dictionary_learning/trainers/standard_new.py: -------------------------------------------------------------------------------- 1 | """ 2 | Implements the SAE training scheme from https://transformer-circuits.pub/2024/april-update/index.html#training-saes 3 | """ 4 | import torch as t 5 | from 
..trainers.trainer import SAETrainer 6 | from ..config import DEBUG 7 | from ..dictionary import AutoEncoderNew 8 | from collections import namedtuple 9 | 10 | class StandardTrainerNew(SAETrainer): 11 | """ 12 | Standard SAE training scheme. 13 | """ 14 | def __init__(self, 15 | dict_class=AutoEncoderNew, 16 | activation_dim=512, 17 | dict_size=64*512, 18 | lr=5e-5, 19 | l1_penalty=1e-1, 20 | lambda_warm_steps=1500, # steps over which to warm up the l1 penalty 21 | decay_start=24000, # when does the lr decay start 22 | steps=30000, # when when does training end 23 | seed=None, 24 | device=None, 25 | wandb_name='StandardTrainerNew_Anthropic', 26 | ): 27 | super().__init__(seed) 28 | 29 | if seed is not None: 30 | t.manual_seed(seed) 31 | t.cuda.manual_seed_all(seed) 32 | 33 | # initialize dictionary 34 | self.ae = dict_class(activation_dim, dict_size) 35 | 36 | self.lr = lr 37 | self.l1_penalty=l1_penalty 38 | self.lambda_warm_steps=lambda_warm_steps 39 | self.decay_start=decay_start 40 | self.steps = steps 41 | self.wandb_name = wandb_name 42 | 43 | if device is None: 44 | self.device = 'cuda' if t.cuda.is_available() else 'cpu' 45 | else: 46 | self.device = device 47 | self.ae.to(self.device) 48 | 49 | self.optimizer = t.optim.Adam(self.ae.parameters(), lr=lr) 50 | def lr_fn(step): 51 | if step < decay_start: 52 | return 1. 
53 | else: 54 | return (steps - step) / (steps - decay_start) 55 | self.scheduler = t.optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=lr_fn) 56 | 57 | def loss(self, x, step=None, logging=False): 58 | # NOTE: not using normalization 59 | # if logging: 60 | # x = x / x.norm(dim=-1).mean() * (self.ae.activation_dim ** 0.5) 61 | 62 | x_hat, f = self.ae(x, output_features=True) 63 | 64 | l1_penalty = self.l1_penalty 65 | l1_penalty = min(1., step / self.lambda_warm_steps) * self.l1_penalty 66 | 67 | L_recon = (x - x_hat).pow(2).sum(dim=-1).mean() 68 | L_sparse = f.norm(p=1, dim=-1).mean() 69 | 70 | loss = L_recon + l1_penalty * L_sparse 71 | 72 | if not logging: 73 | return loss 74 | else: 75 | return namedtuple('LossLog', ['x', 'x_hat', 'f', 'losses'])( 76 | x, x_hat, f, 77 | { 78 | 'mse_loss' : L_recon.item(), 79 | 'sparsity_loss' : L_sparse.item(), 80 | 'l1_penalty' : l1_penalty, 81 | 'loss' : loss.item() 82 | } 83 | ) 84 | 85 | 86 | def update(self, step, x): 87 | x = x.to(self.device) 88 | 89 | # NOTE: not using normalization 90 | # normalization 91 | # x = x / x.norm(dim=-1).mean() * (self.ae.activation_dim ** 0.5) 92 | 93 | self.optimizer.zero_grad() 94 | loss = self.loss(x, step=step) 95 | loss.backward() 96 | 97 | # clip grad norm 98 | t.nn.utils.clip_grad_norm_(self.ae.parameters(), 1.0) 99 | 100 | self.optimizer.step() 101 | self.scheduler.step() 102 | 103 | return loss.item() 104 | 105 | @property 106 | def config(self): 107 | return { 108 | 'trainer_class' : 'StandardTrainerNew', 109 | 'dict_class' : 'AutoEncoderNew', 110 | 'lr' : self.lr, 111 | 'l1_penalty' : self.l1_penalty, 112 | 'lambda_warm_steps' : self.lambda_warm_steps, 113 | 'decay_start' : self.decay_start, 114 | 'steps' : self.steps, 115 | 'seed' : self.seed, 116 | 'activation_dim' : self.ae.activation_dim, 117 | 'dict_size' : self.ae.dict_size, 118 | 'device' : self.device, 119 | 'wandb_name' : self.wandb_name, 120 | } 
-------------------------------------------------------------------------------- /dictionary_learning/trainers/top_k.py: -------------------------------------------------------------------------------- 1 | """ 2 | Implements the SAE training scheme from https://arxiv.org/abs/2406.04093. 3 | Significant portions of this code have been copied from https://github.com/EleutherAI/sae/blob/main/sae 4 | """ 5 | 6 | import einops 7 | import torch as t 8 | import torch.nn as nn 9 | from collections import namedtuple 10 | 11 | from ..config import DEBUG 12 | from ..dictionary import Dictionary 13 | from ..kernels import TritonDecoder 14 | from ..trainers.trainer import SAETrainer 15 | 16 | 17 | @t.no_grad() 18 | def geometric_median(points: t.Tensor, max_iter: int = 100, tol: float = 1e-5): 19 | """Compute the geometric median `points`. Used for initializing decoder bias.""" 20 | # Initialize our guess as the mean of the points 21 | guess = points.mean(dim=0) 22 | prev = t.zeros_like(guess) 23 | 24 | # Weights for iteratively reweighted least squares 25 | weights = t.ones(len(points), device=points.device) 26 | 27 | for _ in range(max_iter): 28 | prev = guess 29 | 30 | # Compute the weights 31 | weights = 1 / t.norm(points - guess, dim=1) 32 | 33 | # Normalize the weights 34 | weights /= weights.sum() 35 | 36 | # Compute the new geometric median 37 | guess = (weights.unsqueeze(1) * points).sum(dim=0) 38 | 39 | # Early stopping condition 40 | if t.norm(guess - prev) < tol: 41 | break 42 | 43 | return guess 44 | 45 | 46 | class AutoEncoderTopK(Dictionary, nn.Module): 47 | """ 48 | The top-k autoencoder architecture and initialization used in https://arxiv.org/abs/2406.04093 49 | """ 50 | def __init__(self, activation_dim, dict_size, k): 51 | super().__init__() 52 | self.activation_dim = activation_dim 53 | self.dict_size = dict_size 54 | self.k = k 55 | 56 | self.encoder = nn.Linear(activation_dim, dict_size) 57 | self.encoder.bias.data.zero_() 58 | 59 | self.decoder = 
nn.Parameter(self.encoder.weight.data.clone()) 60 | self.set_decoder_norm_to_unit_norm() 61 | 62 | self.b_dec = self.b_dec = nn.Parameter(t.zeros(activation_dim)) 63 | 64 | def encode(self, x): 65 | return nn.functional.relu(self.encoder(x - self.b_dec)) 66 | 67 | def decode(self, top_acts, top_indices): 68 | d = TritonDecoder.apply(top_indices, top_acts, self.decoder.mT) 69 | return d + self.b_dec 70 | 71 | def forward(self, x, output_features=False): 72 | # (rangell): some shape hacking going on here 73 | f = self.encode(x.view(-1, x.shape[-1])) 74 | top_acts, top_indices = f.topk(self.k, sorted=False) 75 | x_hat = self.decode(top_acts, top_indices).view(x.shape) 76 | f = f.view(*x.shape[:-1], f.shape[-1]) 77 | if not output_features: 78 | return x_hat 79 | elif output_features == "all": 80 | return x_hat, f, top_acts, top_indices 81 | else: 82 | return x_hat, f 83 | 84 | @t.no_grad() 85 | def set_decoder_norm_to_unit_norm(self): 86 | eps = t.finfo(self.decoder.dtype).eps 87 | norm = t.norm(self.decoder.data, dim=1, keepdim=True) 88 | self.decoder.data /= norm + eps 89 | 90 | @t.no_grad() 91 | def remove_gradient_parallel_to_decoder_directions(self): 92 | assert self.decoder.grad is not None # keep pyright happy 93 | 94 | parallel_component = einops.einsum( 95 | self.decoder.grad, 96 | self.decoder.data, 97 | "d_sae d_in, d_sae d_in -> d_sae", 98 | ) 99 | self.decoder.grad -= einops.einsum( 100 | parallel_component, 101 | self.decoder.data, 102 | "d_sae, d_sae d_in -> d_sae d_in", 103 | ) 104 | 105 | def from_pretrained(path, k=100, device=None): 106 | """ 107 | Load a pretrained autoencoder from a file. 
108 | """ 109 | state_dict = t.load(path, map_location=device) 110 | dict_size, activation_dim = state_dict['encoder.weight'].shape 111 | autoencoder = AutoEncoderTopK(activation_dim, dict_size, k) 112 | autoencoder.load_state_dict(state_dict) 113 | if device is not None: 114 | autoencoder.to(device) 115 | return autoencoder 116 | 117 | 118 | class TrainerTopK(SAETrainer): 119 | """ 120 | Top-K SAE training scheme. 121 | """ 122 | def __init__(self, 123 | dict_class=AutoEncoderTopK, 124 | activation_dim=512, 125 | dict_size=64*512, 126 | k=100, 127 | auxk_alpha=1/32, # see Appendix A.2 128 | decay_start=24000, # when does the lr decay start 129 | steps=30000, # when when does training end 130 | seed=None, 131 | device=None, 132 | layer=None, 133 | lm_name=None, 134 | wandb_name='AutoEncoderTopK', 135 | submodule_name=None, 136 | ): 137 | super().__init__(seed) 138 | 139 | assert layer is not None and lm_name is not None 140 | self.layer = layer 141 | self.lm_name = lm_name 142 | self.submodule_name = submodule_name 143 | 144 | self.wandb_name = wandb_name 145 | self.steps = steps 146 | self.k = k 147 | if seed is not None: 148 | t.manual_seed(seed) 149 | t.cuda.manual_seed_all(seed) 150 | 151 | # Initialise autoencoder 152 | self.ae = dict_class(activation_dim, dict_size, k) 153 | if device is None: 154 | self.device = 'cuda' if t.cuda.is_available() else 'cpu' 155 | else: 156 | self.device = device 157 | self.ae.to(self.device) 158 | 159 | # Auto-select LR using 1 / sqrt(d) scaling law from Figure 3 of the paper 160 | scale = dict_size / (2 ** 14) 161 | self.lr = 2e-4 / scale ** 0.5 162 | self.auxk_alpha = auxk_alpha 163 | self.dead_feature_threshold = 10_000_000 164 | 165 | # Optimizer and scheduler 166 | self.optimizer = t.optim.Adam(self.ae.parameters(), lr=self.lr, betas=(0.9, 0.999)) 167 | def lr_fn(step): 168 | if step < decay_start: 169 | return 1. 
170 | else: 171 | return (steps - step) / (steps - decay_start) 172 | self.scheduler = t.optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=lr_fn) 173 | 174 | # Training parameters 175 | self.num_tokens_since_fired = t.zeros(dict_size, dtype=t.long, device=device) 176 | 177 | # Log the effective L0, i.e. number of features actually used, which should a constant value (K) 178 | # Note: The standard L0 is essentially a measure of dead features for Top-K SAEs) 179 | self.logging_parameters = ["effective_l0", "dead_features"] 180 | self.effective_l0 = -1 181 | self.dead_features = -1 182 | 183 | def loss(self, x, step=None, logging=False): 184 | 185 | x = x.to(self.device) 186 | 187 | # Run the SAE 188 | f = self.ae.encode(x) 189 | top_acts, top_indices = f.topk(self.k, sorted=False) 190 | x_hat = self.ae.decode(top_acts, top_indices) 191 | 192 | # Measure goodness of reconstruction 193 | e = x_hat - x 194 | total_variance = (x - x.mean(0)).pow(2).sum(0) 195 | 196 | # Update the effective L0 (again, should just be K) 197 | self.effective_l0 = top_acts.size(1) 198 | 199 | # Update "number of tokens since fired" for each features 200 | num_tokens_in_step = x.size(0) 201 | did_fire = t.zeros_like(self.num_tokens_since_fired, dtype=t.bool) 202 | did_fire[top_indices.flatten()] = True 203 | self.num_tokens_since_fired += num_tokens_in_step 204 | self.num_tokens_since_fired[did_fire] = 0 205 | 206 | # Compute dead feature mask based on "number of tokens since fired" 207 | dead_mask = ( 208 | self.num_tokens_since_fired > self.dead_feature_threshold 209 | if self.auxk_alpha > 0 210 | else None 211 | ).to(f.device) 212 | self.dead_features = int(dead_mask.sum()) 213 | 214 | # If dead features: Second decoder pass for AuxK loss 215 | if dead_mask is not None and (num_dead := int(dead_mask.sum())) > 0: 216 | 217 | # Heuristic from Appendix B.1 in the paper 218 | k_aux = x.shape[-1] // 2 219 | 220 | # Reduce the scale of the loss if there are a small number of dead latents 221 
| scale = min(num_dead / k_aux, 1.0) 222 | k_aux = min(k_aux, num_dead) 223 | 224 | # Don't include living latents in this loss 225 | auxk_latents = t.where(dead_mask[None], f, -t.inf) 226 | 227 | # Top-k dead latents 228 | auxk_acts, auxk_indices = auxk_latents.topk(k_aux, sorted=False) 229 | 230 | # Encourage the top ~50% of dead latents to predict the residual of the 231 | # top k living latents 232 | e_hat = self.ae.decode(auxk_acts, auxk_indices) 233 | auxk_loss = (e_hat - e).pow(2) #.sum(0) 234 | auxk_loss = scale * t.mean(auxk_loss / total_variance) 235 | else: 236 | auxk_loss = x_hat.new_tensor(0.0) 237 | 238 | l2_loss = e.pow(2).sum(dim=-1).mean() 239 | auxk_loss = auxk_loss.sum(dim=-1).mean() 240 | loss = l2_loss + self.auxk_alpha * auxk_loss 241 | 242 | if not logging: 243 | return loss 244 | else: 245 | return namedtuple('LossLog', ['x', 'x_hat', 'f', 'losses'])( 246 | x, x_hat, f, 247 | { 248 | 'l2_loss': l2_loss.item(), 249 | 'auxk_loss': auxk_loss.item(), 250 | 'loss' : loss.item() 251 | } 252 | ) 253 | 254 | def update(self, step, x): 255 | 256 | x = x.to(self.device) 257 | 258 | # Initialise the decoder bias 259 | if step == 0: 260 | median = geometric_median(x) 261 | self.ae.b_dec.data = median 262 | 263 | # Make sure the decoder is still unit-norm 264 | self.ae.set_decoder_norm_to_unit_norm() 265 | 266 | # compute the loss 267 | x = x.to(self.device) 268 | loss = self.loss(x, step=step) 269 | loss.backward() 270 | 271 | # clip grad norm and remove grads parallel to decoder directions 272 | t.nn.utils.clip_grad_norm_(self.ae.parameters(), 1.0) 273 | self.ae.remove_gradient_parallel_to_decoder_directions() 274 | 275 | # do a training step 276 | self.optimizer.step() 277 | self.optimizer.zero_grad() 278 | self.scheduler.step() 279 | return loss.item() 280 | 281 | @property 282 | def config(self): 283 | return { 284 | 'trainer_class' : 'TrainerTopK', 285 | 'dict_class' : 'AutoEncoderTopK', 286 | 'lr' : self.lr, 287 | 'steps' : self.steps, 288 | 
'seed' : self.seed, 289 | 'activation_dim' : self.ae.activation_dim, 290 | 'dict_size' : self.ae.dict_size, 291 | 'k': self.ae.k, 292 | 'device' : self.device, 293 | "layer" : self.layer, 294 | 'lm_name' : self.lm_name, 295 | 'wandb_name' : self.wandb_name, 296 | 'submodule_name' : self.submodule_name, 297 | } 298 | -------------------------------------------------------------------------------- /dictionary_learning/trainers/trainer.py: -------------------------------------------------------------------------------- 1 | class SAETrainer: 2 | """ 3 | Generic class for implementing SAE training algorithms 4 | """ 5 | def __init__(self, seed=None): 6 | self.seed = seed 7 | self.logging_parameters = [] 8 | 9 | def update(self, 10 | step, # index of step in training 11 | activations, # of shape [batch_size, d_submodule] 12 | ): 13 | pass # implemented by subclasses 14 | 15 | def get_logging_parameters(self): 16 | stats = {} 17 | for param in self.logging_parameters: 18 | if hasattr(self, param): 19 | stats[param] = getattr(self, param) 20 | else: 21 | print(f"Warning: {param} not found in {self}") 22 | return stats 23 | 24 | @property 25 | def config(self): 26 | return { 27 | 'wandb_name': 'trainer', 28 | } 29 | -------------------------------------------------------------------------------- /dictionary_learning/training.py: -------------------------------------------------------------------------------- 1 | """ 2 | Training dictionaries 3 | """ 4 | 5 | import torch as t 6 | from .dictionary import AutoEncoder 7 | import os 8 | from tqdm import tqdm 9 | from .trainers.standard import StandardTrainer 10 | import wandb 11 | import json 12 | from .utils import cfg_filename 13 | # from .evaluation import evaluate 14 | 15 | def trainSAE( 16 | data, 17 | trainer_configs = [ 18 | { 19 | 'trainer' : StandardTrainer, 20 | 'dict_class' : AutoEncoder, 21 | 'activation_dim' : 512, 22 | 'dictionary_size' : 64*512, 23 | 'lr' : 1e-3, 24 | 'l1_penalty' : 1e-1, 25 | 'warmup_steps' : 
1000, 26 | 'resample_steps' : None, 27 | 'seed' : None, 28 | 'wandb_name' : 'StandardTrainer', 29 | } 30 | ], 31 | steps=None, 32 | save_steps=None, 33 | save_dir=None, # use {run} to refer to wandb run 34 | log_steps=None, 35 | activations_split_by_head=False, # set to true if data is shape [batch, pos, num_head, head_dim/resid_dim] 36 | transcoder=False, 37 | ): 38 | """ 39 | Train SAEs using the given trainers 40 | """ 41 | 42 | trainers = [] 43 | for config in trainer_configs: 44 | trainer = config['trainer'] 45 | del config['trainer'] 46 | trainers.append( 47 | trainer( 48 | **config 49 | ) 50 | ) 51 | 52 | if log_steps is not None: 53 | ''' 54 | wandb.init( 55 | entity="sae-training", 56 | project="sae-training", 57 | config={f'{trainer.config["wandb_name"]}-{i}' : trainer.config for i, trainer in enumerate(trainers)} 58 | ) 59 | ''' 60 | # process save_dir in light of run name 61 | if save_dir is not None: 62 | save_dir = save_dir.format(run=wandb.run.name) 63 | 64 | # make save dirs, export config 65 | if save_dir is not None: 66 | save_dirs = [os.path.join(save_dir, f"{cfg_filename(trainer_config)}") for trainer_config in trainer_configs] 67 | for trainer, dir in zip(trainers, save_dirs): 68 | os.makedirs(dir, exist_ok=True) 69 | # save config 70 | config = {'trainer' : trainer.config} 71 | try: 72 | config['buffer'] = data.config 73 | except: pass 74 | with open(os.path.join(dir, "config.json"), 'w') as f: 75 | json.dump(config, f, indent=4) 76 | else: 77 | save_dirs = [None for _ in trainer_configs] 78 | 79 | for step, act in enumerate(tqdm(data, total=steps)): 80 | if steps is not None and step >= steps: 81 | break 82 | 83 | # logging 84 | if log_steps is not None and step % log_steps == 0: 85 | log = {} 86 | with t.no_grad(): 87 | 88 | # quick hack to make sure all trainers get the same x 89 | # TODO make this less hacky 90 | z = act.clone() 91 | for i, trainer in enumerate(trainers): 92 | act = z.clone() 93 | if activations_split_by_head: # x.shape: 
[batch, pos, n_heads, d_head] 94 | act = act[..., i, :] 95 | trainer_name = f'{trainer.config["wandb_name"]}-{i}' 96 | if not transcoder: 97 | act, act_hat, f, losslog = trainer.loss(act, step=step, logging=True) # act is x 98 | 99 | # L0 100 | l0 = (f != 0).float().sum(dim=-1).mean().item() 101 | # fraction of variance explained 102 | total_variance = t.var(act, dim=0).sum() 103 | residual_variance = t.var(act - act_hat, dim=0).sum() 104 | frac_variance_explained = (1 - residual_variance / total_variance) 105 | log[f'{trainer_name}/frac_variance_explained'] = frac_variance_explained.item() 106 | else: # transcoder 107 | x, x_hat, f, losslog = trainer.loss(act, step=step, logging=True) # act is x, y 108 | 109 | # L0 110 | l0 = (f != 0).float().sum(dim=-1).mean().item() 111 | 112 | # fraction of variance explained 113 | # TODO: adapt for transcoder 114 | # total_variance = t.var(x, dim=0).sum() 115 | # residual_variance = t.var(x - x_hat, dim=0).sum() 116 | # frac_variance_explained = (1 - residual_variance / total_variance) 117 | # log[f'{trainer_name}/frac_variance_explained'] = frac_variance_explained.item() 118 | 119 | # log parameters from training 120 | log.update({f'{trainer_name}/{k}' : v for k, v in losslog.items()}) 121 | log[f'{trainer_name}/l0'] = l0 122 | trainer_log = trainer.get_logging_parameters() 123 | for name, value in trainer_log.items(): 124 | log[f'{trainer_name}/{name}'] = value 125 | 126 | # TODO get this to work 127 | # metrics = evaluate( 128 | # trainer.ae, 129 | # data, 130 | # device=trainer.device 131 | # ) 132 | # log.update( 133 | # {f'trainer{i}/{k}' : v for k, v in metrics.items()} 134 | # ) 135 | wandb.log(log, step=step) 136 | 137 | # saving 138 | if save_steps is not None and step % save_steps == 0: 139 | for dir, trainer in zip(save_dirs, trainers): 140 | if dir is not None: 141 | if not os.path.exists(os.path.join(dir, "checkpoints")): 142 | os.mkdir(os.path.join(dir, "checkpoints")) 143 | t.save( 144 | 
trainer.ae.state_dict(), 145 | os.path.join(dir, "checkpoints", f"ae_{step}.pt") 146 | ) 147 | 148 | # training 149 | for trainer in trainers: 150 | trainer.update(step, act) 151 | 152 | # save final SAEs 153 | for save_dir, trainer in zip(save_dirs, trainers): 154 | if save_dir is not None: 155 | t.save(trainer.ae.state_dict(), os.path.join(save_dir, "ae.pt")) 156 | 157 | # End the wandb run 158 | ''' 159 | if log_steps is not None: 160 | wandb.finish() 161 | ''' -------------------------------------------------------------------------------- /dictionary_learning/utils.py: -------------------------------------------------------------------------------- 1 | from datasets import load_dataset 2 | import zstandard as zstd 3 | import io 4 | import json 5 | import argparse 6 | 7 | def hf_dataset_to_generator(dataset_name, split='train', streaming=True): 8 | dataset = load_dataset(dataset_name, split=split, streaming=streaming) 9 | 10 | def gen(): 11 | for x in iter(dataset): 12 | yield x['text'] 13 | 14 | return gen() 15 | 16 | def zst_to_generator(data_path): 17 | """ 18 | Load a dataset from a .jsonl.zst file. 
19 | The jsonl entries is assumed to have a 'text' field 20 | """ 21 | compressed_file = open(data_path, 'rb') 22 | dctx = zstd.ZstdDecompressor() 23 | reader = dctx.stream_reader(compressed_file) 24 | text_stream = io.TextIOWrapper(reader, encoding='utf-8') 25 | def generator(): 26 | for line in text_stream: 27 | yield json.loads(line)['text'] 28 | return generator() 29 | 30 | def cfg_filename(cfg): 31 | result = [] 32 | for key in cfg: 33 | value = str(cfg[key]) 34 | value = value.replace("/", "") 35 | result.append(f"{key}:{value[-20:]}") 36 | return '_'.join(result) 37 | 38 | def str2bool(v): 39 | if isinstance(v, bool): 40 | return v 41 | if v.lower() in ('yes', 'true', 't', 'y', '1'): 42 | return True 43 | elif v.lower() in ('no', 'false', 'f', 'n', '0'): 44 | return False 45 | else: 46 | raise argparse.ArgumentTypeError('Boolean value expected.') 47 | -------------------------------------------------------------------------------- /results/1on/l0_deltace.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amudide/switch_sae/1b484570c74063b8331cb658a6bbf1964048a7de/results/1on/l0_deltace.png -------------------------------------------------------------------------------- /results/1on/l0_mse.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amudide/switch_sae/1b484570c74063b8331cb658a6bbf1964048a7de/results/1on/l0_mse.png -------------------------------------------------------------------------------- /results/1on/l0_recovered.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amudide/switch_sae/1b484570c74063b8331cb658a6bbf1964048a7de/results/1on/l0_recovered.png -------------------------------------------------------------------------------- /results/1on_lb/1on_lb.csv: -------------------------------------------------------------------------------- 1 | 
"Name","State","Notes","User","Tags","Created","Runtime","Sweep","SwitchAutoEncoder-0.activation_dim","SwitchAutoEncoder-0.auxk_alpha","SwitchAutoEncoder-0.decay_start","SwitchAutoEncoder-0.device","SwitchAutoEncoder-0.dict_class","SwitchAutoEncoder-0.dict_size","SwitchAutoEncoder-0.experts","SwitchAutoEncoder-0.heaviside","SwitchAutoEncoder-0.k","SwitchAutoEncoder-0.layer","SwitchAutoEncoder-0.lb_alpha","SwitchAutoEncoder-0.lm_name","SwitchAutoEncoder-0.seed","SwitchAutoEncoder-0.steps","SwitchAutoEncoder-0.trainer","SwitchAutoEncoder-0.wandb_name","SwitchAutoEncoder-1.activation_dim","SwitchAutoEncoder-1.auxk_alpha","SwitchAutoEncoder-1.decay_start","SwitchAutoEncoder-1.device","SwitchAutoEncoder-1.dict_class","SwitchAutoEncoder-1.dict_size","SwitchAutoEncoder-1.experts","SwitchAutoEncoder-1.heaviside","SwitchAutoEncoder-1.k","SwitchAutoEncoder-1.layer","SwitchAutoEncoder-1.lb_alpha","SwitchAutoEncoder-1.lm_name","SwitchAutoEncoder-1.seed","SwitchAutoEncoder-1.steps","SwitchAutoEncoder-1.trainer","SwitchAutoEncoder-1.wandb_name","SwitchAutoEncoder-0/auxk_loss","SwitchAutoEncoder-0/cossim","SwitchAutoEncoder-0/dead_features","SwitchAutoEncoder-0/effective_l0","SwitchAutoEncoder-0/frac_alive","SwitchAutoEncoder-0/frac_recovered","SwitchAutoEncoder-0/frac_variance_explained","SwitchAutoEncoder-0/l0","SwitchAutoEncoder-0/l1_loss","SwitchAutoEncoder-0/l2_loss","SwitchAutoEncoder-0/l2_ratio","SwitchAutoEncoder-0/lb_loss","SwitchAutoEncoder-0/loss","SwitchAutoEncoder-0/loss_original","SwitchAutoEncoder-0/loss_reconstructed","SwitchAutoEncoder-0/loss_zero","SwitchAutoEncoder-0/mse_loss","SwitchAutoEncoder-1/auxk_loss","SwitchAutoEncoder-1/cossim","SwitchAutoEncoder-1/dead_features","SwitchAutoEncoder-1/effective_l0","SwitchAutoEncoder-1/frac_alive","SwitchAutoEncoder-1/frac_recovered","SwitchAutoEncoder-1/frac_variance_explained","SwitchAutoEncoder-1/l0","SwitchAutoEncoder-1/l1_loss","SwitchAutoEncoder-1/l2_loss","SwitchAutoEncoder-1/l2_ratio","SwitchAutoEncoder-1/lb_loss
","SwitchAutoEncoder-1/loss","SwitchAutoEncoder-1/loss_original","SwitchAutoEncoder-1/loss_reconstructed","SwitchAutoEncoder-1/loss_zero","SwitchAutoEncoder-1/mse_loss" 2 | "clean-frog-6","finished","-","","","2024-07-30T01:01:23.000Z","47755","","768","0.03125","80000","cuda:7","dictionary_learning.trainers.switch1on.SwitchAutoEncoder","24576","32","false","64","8","100","openai-community/gpt2","0","100000","dictionary_learning.trainers.switch1on.SwitchTrainer","SwitchAutoEncoder","768","0.03125","80000","cuda:7","dictionary_learning.trainers.switch1on.SwitchAutoEncoder","24576","32","false","64","8","300","openai-community/gpt2","0","100000","dictionary_learning.trainers.switch1on.SwitchTrainer","SwitchAutoEncoder","0.000044464417442213744","0.9567948579788208","3228","64","0.00004069010537932627","0.9899426102638244","0.9863746166229248","253.58837890625","654.6430053710938","34.097957611083984","0.9548745155334472","1.0313754081726074","80446.0546875","3.3303747177124023","3.471522569656372","17.364582061767578","1241.935791015625","0.00004988056025467813","0.9448071718215942","8491","64","0.00004069010537932627","0.9861976504325868","0.9729151725769044","408.0152587890625","714.1879272460938","38.533409118652344","0.942345142364502","1.0331422090530396","239588.09375","3.26992130279541","3.467707872390747","17.59981346130371","1566.963623046875" 3 | 
"fanciful-oath-4","finished","-","","","2024-07-30T01:01:23.000Z","46459","","768","0.03125","80000","cuda:6","dictionary_learning.trainers.switch1on.SwitchAutoEncoder","24576","32","false","64","8","10","openai-community/gpt2","0","100000","dictionary_learning.trainers.switch1on.SwitchTrainer","SwitchAutoEncoder","768","0.03125","80000","cuda:6","dictionary_learning.trainers.switch1on.SwitchAutoEncoder","24576","32","false","64","8","30","openai-community/gpt2","0","100000","dictionary_learning.trainers.switch1on.SwitchTrainer","SwitchAutoEncoder","0.000005703236183762783","0.9625418186187744","52","64","0.00004069010537932627","0.9920895099639891","0.9881578087806702","195.4332275390625","607.35693359375","31.696577072143555","0.9622729420661926","1.0371432304382324","9045.2392578125","3.3303747177124023","3.441391944885254","17.364582061767578","1079.396484375","0.000008642778993817046","0.9617507457733154","79","64","0.00004069010537932627","0.9917919039726256","0.9810457825660706","192.5859375","579.9859008789062","32.029884338378906","0.9615445733070374","1.0366272926330566","24975.3359375","3.26992130279541","3.387542963027954","17.59981346130371","1096.583740234375" 4 | 
"wild-valley-4","finished","-","","","2024-07-30T01:01:23.000Z","45651","","768","0.03125","80000","cuda:5","dictionary_learning.trainers.switch1on.SwitchAutoEncoder","24576","32","false","64","8","1","openai-community/gpt2","0","100000","dictionary_learning.trainers.switch1on.SwitchTrainer","SwitchAutoEncoder","768","0.03125","80000","cuda:5","dictionary_learning.trainers.switch1on.SwitchAutoEncoder","24576","32","false","64","8","3","openai-community/gpt2","0","100000","dictionary_learning.trainers.switch1on.SwitchTrainer","SwitchAutoEncoder","7.480709314222622e-7","0.9632782936096193","7","64","0.00004069010537932627","0.9921401739120485","0.988391637802124","197.5400390625","603.214599609375","31.353612899780273","0.96281236410141","1.0378597974777222","1851.8035888671875","3.3303747177124023","3.4406819343566895","17.364582061767578","1058.0706787109375","0.000002888723429350648","0.9625215530395508","27","64","0.00004069010537932627","0.9922477006912231","0.9814417958259584","198.2266845703125","574.055419921875","31.630542755126953","0.9623088836669922","1.0359835624694824","3454.05517578125","3.26992130279541","3.3810112476348877","17.59981346130371","1073.6748046875" 5 | 
"different-feather-3","finished","-","","","2024-07-30T00:57:52.000Z","46911","","768","0.03125","80000","cuda:7","dictionary_learning.trainers.switch1on.SwitchAutoEncoder","24576","32","false","64","8","0.1","openai-community/gpt2","0","100000","dictionary_learning.trainers.switch1on.SwitchTrainer","SwitchAutoEncoder","768","0.03125","80000","cuda:7","dictionary_learning.trainers.switch1on.SwitchAutoEncoder","24576","32","false","64","8","0.3","openai-community/gpt2","0","100000","dictionary_learning.trainers.switch1on.SwitchTrainer","SwitchAutoEncoder","0.00004044335219077766","0.963446080684662","364","64","0.00004069010537932627","0.9921776056289672","0.9884382486343384","199.870361328125","601.8258056640625","31.283111572265625","0.9632078409194946","1.1113539934158323","1141.703369140625","3.3303747177124023","3.4401562213897705","17.364582061767578","1053.83154296875","0.000005634989520331146","0.9629814624786376","52","64","0.00004069010537932627","0.992292046546936","0.9816661477088928","199.752685546875","573.129638671875","31.42491912841797","0.962762176990509","1.045872926712036","1295.9554443359375","3.26992130279541","3.380375623703003","17.59981346130371","1060.68994140625" 6 | 
"feasible-plasma-1","finished","-","","","2024-07-30T00:57:40.000Z","47790","","768","0.03125","80000","cuda:6","dictionary_learning.trainers.switch1on.SwitchAutoEncoder","24576","32","false","64","8","0.01","openai-community/gpt2","0","100000","dictionary_learning.trainers.switch1on.SwitchTrainer","SwitchAutoEncoder","768","0.03125","80000","cuda:6","dictionary_learning.trainers.switch1on.SwitchAutoEncoder","24576","32","false","64","8","0.03","openai-community/gpt2","0","100000","dictionary_learning.trainers.switch1on.SwitchTrainer","SwitchAutoEncoder","0.000043194224417675287","0.9620879292488098","1624","64","0.00004069010537932627","0.991884469985962","0.988021969795227","203.5709228515625","604.99609375","31.848587036132812","0.9618685245513916","1.6018588542938232","1101.450927734375","3.3303747177124023","3.444269895553589","17.364582061767578","1091.7822265625","0.000042685824155341834","0.9621524214744568","1215","64","0.00004069010537932627","0.9919402003288268","0.9812399744987488","202.982666015625","575.129638671875","31.77358627319336","0.962187886238098","1.3047212362289429","1103.21728515625","3.26992130279541","3.3854167461395264","17.59981346130371","1085.34228515625" 7 | 
"summer-surf-2","finished","-","","","2024-07-30T00:57:40.000Z","47371","","768","0.03125","80000","cuda:5","dictionary_learning.trainers.switch1on.SwitchAutoEncoder","24576","32","false","64","8","0.001","openai-community/gpt2","0","100000","dictionary_learning.trainers.switch1on.SwitchTrainer","SwitchAutoEncoder","768","0.03125","80000","cuda:5","dictionary_learning.trainers.switch1on.SwitchAutoEncoder","24576","32","false","64","8","0.003","openai-community/gpt2","0","100000","dictionary_learning.trainers.switch1on.SwitchTrainer","SwitchAutoEncoder","0.000043277785152895376","0.9617789387702942","2094","64","0.00004069010537932627","0.9917656779289246","0.9879220128059388","204.457763671875","608.1683349609375","31.98760986328125","0.9616813659667968","1.7181545495986938","1099.705322265625","3.3303747177124023","3.445937156677246","17.364582061767578","1100.883056640625","0.00004343598266132176","0.9613968133926392","1650","64","0.00004069010537932627","0.9916406869888306","0.980884611606598","205.2655029296875","578.4508056640625","32.08662796020508","0.961195707321167","1.6460076570510864","1097.1473388671875","3.26992130279541","3.38970947265625","17.59981346130371","1105.91845703125" -------------------------------------------------------------------------------- /results/1on_lb/alpha_deltace.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amudide/switch_sae/1b484570c74063b8331cb658a6bbf1964048a7de/results/1on_lb/alpha_deltace.png -------------------------------------------------------------------------------- /results/1on_lb/alpha_lossrec.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amudide/switch_sae/1b484570c74063b8331cb658a6bbf1964048a7de/results/1on_lb/alpha_lossrec.png -------------------------------------------------------------------------------- /results/1on_lb/alpha_mse.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/amudide/switch_sae/1b484570c74063b8331cb658a6bbf1964048a7de/results/1on_lb/alpha_mse.png -------------------------------------------------------------------------------- /results/efficiency/efficiency.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amudide/switch_sae/1b484570c74063b8331cb658a6bbf1964048a7de/results/efficiency/efficiency.png -------------------------------------------------------------------------------- /results/gated_lr/gated_l0_lossrec.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amudide/switch_sae/1b484570c74063b8331cb658a6bbf1964048a7de/results/gated_lr/gated_l0_lossrec.png -------------------------------------------------------------------------------- /results/gated_lr/gated_l0_mse.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amudide/switch_sae/1b484570c74063b8331cb658a6bbf1964048a7de/results/gated_lr/gated_l0_mse.png -------------------------------------------------------------------------------- /results/gated_lr/gated_mse_deltace.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amudide/switch_sae/1b484570c74063b8331cb658a6bbf1964048a7de/results/gated_lr/gated_mse_deltace.png -------------------------------------------------------------------------------- /results/gated_lr/primary-gated-1e-3.csv: -------------------------------------------------------------------------------- 1 | 
"Name","State","Notes","User","Tags","Created","Runtime","Sweep","GatedSAETrainer-0.activation_dim","GatedSAETrainer-0.device","GatedSAETrainer-0.dict_class","GatedSAETrainer-0.dict_size","GatedSAETrainer-0.l1_penalty","GatedSAETrainer-0.layer","GatedSAETrainer-0.lm_name","GatedSAETrainer-0.lr","GatedSAETrainer-0.seed","GatedSAETrainer-0.trainer","GatedSAETrainer-0.wandb_name","GatedSAETrainer-0.warmup_steps","GatedSAETrainer-1.activation_dim","GatedSAETrainer-1.device","GatedSAETrainer-1.dict_class","GatedSAETrainer-1.dict_size","GatedSAETrainer-1.l1_penalty","GatedSAETrainer-1.layer","GatedSAETrainer-1.lm_name","GatedSAETrainer-1.lr","GatedSAETrainer-1.seed","GatedSAETrainer-1.trainer","GatedSAETrainer-1.wandb_name","GatedSAETrainer-1.warmup_steps","GatedSAETrainer-0/aux_loss","GatedSAETrainer-0/cossim","GatedSAETrainer-0/frac_alive","GatedSAETrainer-0/frac_recovered","GatedSAETrainer-0/frac_variance_explained","GatedSAETrainer-0/l0","GatedSAETrainer-0/l1_loss","GatedSAETrainer-0/l2_loss","GatedSAETrainer-0/l2_ratio","GatedSAETrainer-0/loss","GatedSAETrainer-0/loss_original","GatedSAETrainer-0/loss_reconstructed","GatedSAETrainer-0/loss_zero","GatedSAETrainer-0/mse_loss","GatedSAETrainer-0/sparsity_loss","GatedSAETrainer-1/aux_loss","GatedSAETrainer-1/cossim","GatedSAETrainer-1/frac_alive","GatedSAETrainer-1/frac_recovered","GatedSAETrainer-1/frac_variance_explained","GatedSAETrainer-1/l0","GatedSAETrainer-1/l1_loss","GatedSAETrainer-1/l2_loss","GatedSAETrainer-1/l2_ratio","GatedSAETrainer-1/loss","GatedSAETrainer-1/loss_original","GatedSAETrainer-1/loss_reconstructed","GatedSAETrainer-1/loss_zero","GatedSAETrainer-1/mse_loss","GatedSAETrainer-1/sparsity_loss" 2 | 
"sunny-planet-5","finished","-","","","2024-07-16T17:57:45.000Z","64564","","768","cuda:7","dictionary_learning.dictionary.GatedAutoEncoder","24576","20.56","8","openai-community/gpt2","0.001","0","dictionary_learning.trainers.gdm.GatedSAETrainer","GatedSAETrainer","1000","768","cuda:7","dictionary_learning.dictionary.GatedAutoEncoder","24576","30","8","openai-community/gpt2","0.001","0","dictionary_learning.trainers.gdm.GatedSAETrainer","GatedSAETrainer","1000","3218.92138671875","0.920775294303894","0.00004069010537932627","0.9675883054733276","0.9754146933555604","14.7557373046875","277.45941162109375","46.65650939941406","0.9305692911148072","9132.3642578125","3.3303747177124023","3.7852470874786377","17.364582061767578","2245.533203125","175.97865295410156","4237.42578125","0.8820405602455139","0.00004069010537932627","0.9322636723518372","0.943988561630249","7.1939697265625","159.96339416503906","55.96257400512695","0.882667064666748","10325.232421875","3.26992130279541","4.240576267242432","17.59981346130371","3240.766845703125","95.46398162841795" 3 | 
"clear-frog-4","finished","-","","","2024-07-16T17:57:39.000Z","65896","","768","cuda:6","dictionary_learning.dictionary.GatedAutoEncoder","24576","9.65","8","openai-community/gpt2","0.001","0","dictionary_learning.trainers.gdm.GatedSAETrainer","GatedSAETrainer","1000","768","cuda:6","dictionary_learning.dictionary.GatedAutoEncoder","24576","14.09","8","openai-community/gpt2","0.001","0","dictionary_learning.trainers.gdm.GatedSAETrainer","GatedSAETrainer","1000","1701.8465576171875","0.9534698724746704","0.00004069010537932627","0.9877625703811646","0.9854307770729064","38.939208984375","344.64532470703125","36.01933670043945","0.9426600933074952","5541.484375","3.3303747177124023","3.5021169185638428","17.364582061767578","1338.38232421875","262.4046325683594","2458.522705078125","0.94084894657135","0.00004069010537932627","0.981767475605011","0.9708531498908995","25.7823486328125","310.1355285644531","40.50748062133789","0.9534173607826232","7219.072265625","3.26992130279541","3.5311920642852783","17.59981346130371","1692.3714599609375","211.6930694580078" 4 | 
"drawn-shadow-2","finished","-","","","2024-07-16T17:56:25.000Z","66134","","768","cuda:6","dictionary_learning.dictionary.GatedAutoEncoder","24576","4.53","8","openai-community/gpt2","0.001","0","dictionary_learning.trainers.gdm.GatedSAETrainer","GatedSAETrainer","1000","768","cuda:6","dictionary_learning.dictionary.GatedAutoEncoder","24576","6.62","8","openai-community/gpt2","0.001","0","dictionary_learning.trainers.gdm.GatedSAETrainer","GatedSAETrainer","1000","1053.4622802734375","0.9675292372703552","0.00004069010537932627","0.9922187924385072","0.9881061911582948","92.7789306640625","432.4425048828125","31.512962341308594","0.937780499458313","3639.833984375","3.3303747177124023","3.4395782947540283","17.364582061767578","1114.6705322265625","375.8903198242187","1358.39306640625","0.9616646766662598","0.00004069010537932627","0.9908108115196228","0.9810226559638976","62.830810546875","380.2125244140625","32.97096633911133","0.9485290050506592","4437.47705078125","3.26992130279541","3.401602268218994","17.59981346130371","1121.4132080078125","303.4264221191406" 5 | 
"hearty-music-1","finished","-","","","2024-07-16T17:56:25.000Z","65688","","768","cuda:5","dictionary_learning.dictionary.GatedAutoEncoder","24576","2.13","8","openai-community/gpt2","0.001","0","dictionary_learning.trainers.gdm.GatedSAETrainer","GatedSAETrainer","1000","768","cuda:5","dictionary_learning.dictionary.GatedAutoEncoder","24576","3.11","8","openai-community/gpt2","0.001","0","dictionary_learning.trainers.gdm.GatedSAETrainer","GatedSAETrainer","1000","899.0625610351562","0.9798789620399476","0.00004069010537932627","0.994871973991394","0.9936583042144777","273.841796875","704.4373168945312","24.840396881103516","0.9749097228050232","2705.451416015625","3.3303747177124023","3.4023427963256836","17.364582061767578","684.5126953125","537.3072509765625","1063.873046875","0.9749206304550172","0.00004069010537932627","0.9942847490310668","0.9875259399414062","173.9610595703125","553.797119140625","27.179105758666992","0.981850266456604","3166.013916015625","3.26992130279541","3.3518197536468506","17.59981346130371","779.90478515625","412.6763000488281" 6 | 
"smooth-aardvark-2","finished","-","","","2024-07-16T17:56:25.000Z","65729","","768","cuda:4","dictionary_learning.dictionary.GatedAutoEncoder","24576","1","8","openai-community/gpt2","0.001","0","dictionary_learning.trainers.gdm.GatedSAETrainer","GatedSAETrainer","1000","768","cuda:4","dictionary_learning.dictionary.GatedAutoEncoder","24576","1.46","8","openai-community/gpt2","0.001","0","dictionary_learning.trainers.gdm.GatedSAETrainer","GatedSAETrainer","1000","725.191650390625","0.9857865571975708","0.00004069010537932627","0.9952702522277832","0.9964048862457277","481.456787109375","998.9766235351562","21.06471061706543","0.99408757686615","2107.7939453125","3.3303747177124023","3.3967533111572266","17.364582061767578","518.01513671875","808.7833251953125","741.2411499023438","0.9823729991912842","0.00004069010537932627","0.9945960640907288","0.9929369688034058","380.1922607421875","852.5096435546875","22.97112274169922","0.9866905212402344","2290.761474609375","3.26992130279541","3.347358465194702","17.59981346130371","580.014404296875","680.2507934570312" -------------------------------------------------------------------------------- /results/gated_lr/primary-gated-3e-4.csv: -------------------------------------------------------------------------------- 1 | 
"Name","State","Notes","User","Tags","Created","Runtime","Sweep","GatedSAETrainer-0.activation_dim","GatedSAETrainer-0.device","GatedSAETrainer-0.dict_class","GatedSAETrainer-0.dict_size","GatedSAETrainer-0.l1_penalty","GatedSAETrainer-0.layer","GatedSAETrainer-0.lm_name","GatedSAETrainer-0.lr","GatedSAETrainer-0.seed","GatedSAETrainer-0.trainer","GatedSAETrainer-0.wandb_name","GatedSAETrainer-0.warmup_steps","GatedSAETrainer-1.activation_dim","GatedSAETrainer-1.device","GatedSAETrainer-1.dict_class","GatedSAETrainer-1.dict_size","GatedSAETrainer-1.l1_penalty","GatedSAETrainer-1.layer","GatedSAETrainer-1.lm_name","GatedSAETrainer-1.lr","GatedSAETrainer-1.seed","GatedSAETrainer-1.trainer","GatedSAETrainer-1.wandb_name","GatedSAETrainer-1.warmup_steps","GatedSAETrainer-0/aux_loss","GatedSAETrainer-0/cossim","GatedSAETrainer-0/frac_alive","GatedSAETrainer-0/frac_recovered","GatedSAETrainer-0/frac_variance_explained","GatedSAETrainer-0/l0","GatedSAETrainer-0/l1_loss","GatedSAETrainer-0/l2_loss","GatedSAETrainer-0/l2_ratio","GatedSAETrainer-0/loss","GatedSAETrainer-0/loss_original","GatedSAETrainer-0/loss_reconstructed","GatedSAETrainer-0/loss_zero","GatedSAETrainer-0/mse_loss","GatedSAETrainer-0/sparsity_loss","GatedSAETrainer-1/aux_loss","GatedSAETrainer-1/cossim","GatedSAETrainer-1/frac_alive","GatedSAETrainer-1/frac_recovered","GatedSAETrainer-1/frac_variance_explained","GatedSAETrainer-1/l0","GatedSAETrainer-1/l1_loss","GatedSAETrainer-1/l2_loss","GatedSAETrainer-1/l2_ratio","GatedSAETrainer-1/loss","GatedSAETrainer-1/loss_original","GatedSAETrainer-1/loss_reconstructed","GatedSAETrainer-1/loss_zero","GatedSAETrainer-1/mse_loss","GatedSAETrainer-1/sparsity_loss" 2 | 
"rare-firefly-6","finished","-","","","2024-07-11T18:51:52.000Z","51591","","768","cuda:4","dictionary_learning.dictionary.GatedAutoEncoder","24576","5","8","openai-community/gpt2","0.0003","0","dictionary_learning.trainers.gdm.GatedSAETrainer","GatedSAETrainer","1000","768","cuda:4","dictionary_learning.dictionary.GatedAutoEncoder","24576","6","8","openai-community/gpt2","0.0003","0","dictionary_learning.trainers.gdm.GatedSAETrainer","GatedSAETrainer","1000","1163.14306640625","0.973246932029724","0.00004069010537932627","0.9953534007072448","0.9913633465766908","127.900146484375","528.8916015625","27.75778579711914","0.9604530334472656","3927.76611328125","3.3303747177124023","3.3955864906311035","17.364582061767578","791.554931640625","394.5193481445313","1230.1353759765625","0.9700021743774414","0.00004069010537932627","0.9947134852409364","0.984986126422882","104.4046630859375","507.0380859375","29.08409881591797","0.9812511205673218","4367.89501953125","3.26992130279541","3.3456764221191406","17.59981346130371","870.0953979492188","376.71624755859375" 3 | 
"dashing-lion-5","finished","-","","","2024-07-11T05:25:59.000Z","50846","","768","cuda:7","dictionary_learning.dictionary.GatedAutoEncoder","24576","20.56","8","openai-community/gpt2","0.0003","0","dictionary_learning.trainers.gdm.GatedSAETrainer","GatedSAETrainer","1000","768","cuda:7","dictionary_learning.dictionary.GatedAutoEncoder","24576","30","8","openai-community/gpt2","0.0003","0","dictionary_learning.trainers.gdm.GatedSAETrainer","GatedSAETrainer","1000","3559.26806640625","0.9084912538528442","0.00004069010537932627","0.9579382538795472","0.9719541668891908","12.566650390625","378.3607177734375","49.76805114746094","0.895888090133667","11920.619140625","3.3303747177124023","3.92067813873291","17.364582061767578","2559.174560546875","280.8070068359375","4629.88525390625","0.8717569708824158","0.00004069010537932627","0.9255402088165284","0.9396619200706482","8.206787109375","312.36065673828125","58.12748718261719","0.8806691765785217","15425.427734375","3.26992130279541","4.336921691894531","17.59981346130371","3492.701171875","243.12893676757812" 4 | 
"graceful-totem-4","finished","-","","","2024-07-11T05:25:45.000Z","51864","","768","cuda:6","dictionary_learning.dictionary.GatedAutoEncoder","24576","9.65","8","openai-community/gpt2","0.0003","0","dictionary_learning.trainers.gdm.GatedSAETrainer","GatedSAETrainer","1000","768","cuda:6","dictionary_learning.dictionary.GatedAutoEncoder","24576","14.09","8","openai-community/gpt2","0.0003","0","dictionary_learning.trainers.gdm.GatedSAETrainer","GatedSAETrainer","1000","1701.7183837890625","0.9564532041549684","0.00004069010537932627","0.9901224970817566","0.9862537980079652","48.8880615234375","440.683349609375","34.85649871826172","0.9686530232429504","6070.19482421875","3.3303747177124023","3.4689981937408447","17.364582061767578","1253.7666015625","322.1306457519531","2614.995361328125","0.934950828552246","0.00004069010537932627","0.9785899519920348","0.968268096446991","23.200927734375","387.74371337890625","42.21244812011719","0.9180249571800232","8781.263671875","3.26992130279541","3.5767250061035156","17.59981346130371","1837.4671630859375","305.5757751464844" 5 | 
"silver-puddle-3","finished","-","","","2024-07-11T05:25:35.000Z","51579","","768","cuda:5","dictionary_learning.dictionary.GatedAutoEncoder","24576","4.53","8","openai-community/gpt2","0.0003","0","dictionary_learning.trainers.gdm.GatedSAETrainer","GatedSAETrainer","1000","768","cuda:5","dictionary_learning.dictionary.GatedAutoEncoder","24576","6.62","8","openai-community/gpt2","0.0003","0","dictionary_learning.trainers.gdm.GatedSAETrainer","GatedSAETrainer","1000","1059.000244140625","0.975863754749298","0.00004069010537932627","0.9960606098175048","0.9922483563423156","159.738525390625","578.526123046875","26.311534881591797","0.9801684617996216","3686.526611328125","3.3303747177124023","3.3856611251831055","17.364582061767578","710.6533203125","420.4597473144531","1368.194091796875","0.9670422077178956","0.00004069010537932627","0.99338698387146","0.9835944175720216","82.66162109375","464.7086181640625","30.388513565063477","0.9570202827453612","4679.21875","3.26992130279541","3.364684820175171","17.59981346130371","950.4469604492188","356.64434814453125" 6 | 
"generous-fog-2","finished","-","","","2024-07-11T05:25:20.000Z","51422","","768","cuda:4","dictionary_learning.dictionary.GatedAutoEncoder","24576","2.13","8","openai-community/gpt2","0.0003","0","dictionary_learning.trainers.gdm.GatedSAETrainer","GatedSAETrainer","1000","768","cuda:4","dictionary_learning.dictionary.GatedAutoEncoder","24576","3.11","8","openai-community/gpt2","0.0003","0","dictionary_learning.trainers.gdm.GatedSAETrainer","GatedSAETrainer","1000","603.53125","0.9911146759986876","0.00004069010537932627","0.998424232006073","0.9973034858703612","529.562744140625","1140.15185546875","16.387954711914062","0.993000864982605","2631.47314453125","3.3303747177124023","3.352489471435547","17.364582061767578","278.48040771484375","825.9996337890625","819.4901123046875","0.9827598333358764","0.00004069010537932627","0.9972550868988036","0.9914193153381348","289.6943359375","723.2354736328125","22.326417922973633","0.9679235219955444","3019.125","3.26992130279541","3.3092551231384277","17.59981346130371","508.3369140625","546.3135986328125" 7 | 
"pleasant-oath-1","finished","-","","","2024-07-11T05:21:07.000Z","52028","","768","cuda:6","dictionary_learning.dictionary.GatedAutoEncoder","24576","1","8","openai-community/gpt2","0.0003","0","dictionary_learning.trainers.gdm.GatedSAETrainer","GatedSAETrainer","1000","768","cuda:6","dictionary_learning.dictionary.GatedAutoEncoder","24576","1.46","8","openai-community/gpt2","0.0003","0","dictionary_learning.trainers.gdm.GatedSAETrainer","GatedSAETrainer","1000","245.25555419921875","0.9967520833015442","0.00004069010537932627","0.9992234110832214","0.9996508359909058","645.68798828125","1738.0128173828125","9.955805778503418","0.995638370513916","1791.7958984375","3.3303747177124023","3.341273546218872","17.364582061767578","102.31087493896484","1449.95556640625","388.83349609375","0.9960089921951294","0.00004069010537932627","0.9988710284233092","0.9988107085227966","565.810791015625","1553.6287841796875","10.871599197387695","0.9927268028259276","2247.65771484375","3.26992130279541","3.2860991954803467","17.59981346130371","119.66344451904295","1189.346923828125" -------------------------------------------------------------------------------- /results/gated_lr/primary-gated-5e-5.csv: -------------------------------------------------------------------------------- 1 | 
"Name","State","Notes","User","Tags","Created","Runtime","Sweep","GatedSAETrainer-0.activation_dim","GatedSAETrainer-0.device","GatedSAETrainer-0.dict_class","GatedSAETrainer-0.dict_size","GatedSAETrainer-0.l1_penalty","GatedSAETrainer-0.layer","GatedSAETrainer-0.lm_name","GatedSAETrainer-0.lr","GatedSAETrainer-0.seed","GatedSAETrainer-0.trainer","GatedSAETrainer-0.wandb_name","GatedSAETrainer-0.warmup_steps","GatedSAETrainer-1.activation_dim","GatedSAETrainer-1.device","GatedSAETrainer-1.dict_class","GatedSAETrainer-1.dict_size","GatedSAETrainer-1.l1_penalty","GatedSAETrainer-1.layer","GatedSAETrainer-1.lm_name","GatedSAETrainer-1.lr","GatedSAETrainer-1.seed","GatedSAETrainer-1.trainer","GatedSAETrainer-1.wandb_name","GatedSAETrainer-1.warmup_steps","GatedSAETrainer-0/aux_loss","GatedSAETrainer-0/cossim","GatedSAETrainer-0/frac_alive","GatedSAETrainer-0/frac_recovered","GatedSAETrainer-0/frac_variance_explained","GatedSAETrainer-0/l0","GatedSAETrainer-0/l1_loss","GatedSAETrainer-0/l2_loss","GatedSAETrainer-0/l2_ratio","GatedSAETrainer-0/loss","GatedSAETrainer-0/loss_original","GatedSAETrainer-0/loss_reconstructed","GatedSAETrainer-0/loss_zero","GatedSAETrainer-0/mse_loss","GatedSAETrainer-0/sparsity_loss","GatedSAETrainer-1/aux_loss","GatedSAETrainer-1/cossim","GatedSAETrainer-1/frac_alive","GatedSAETrainer-1/frac_recovered","GatedSAETrainer-1/frac_variance_explained","GatedSAETrainer-1/l0","GatedSAETrainer-1/l1_loss","GatedSAETrainer-1/l2_loss","GatedSAETrainer-1/l2_ratio","GatedSAETrainer-1/loss","GatedSAETrainer-1/loss_original","GatedSAETrainer-1/loss_reconstructed","GatedSAETrainer-1/loss_zero","GatedSAETrainer-1/mse_loss","GatedSAETrainer-1/sparsity_loss" 2 | 
"faithful-yogurt-5","finished","-","","","2024-07-13T08:40:04.000Z","78097","","768","cuda:7","dictionary_learning.dictionary.GatedAutoEncoder","24576","20.56","8","openai-community/gpt2","0.00005","0","dictionary_learning.trainers.gdm.GatedSAETrainer","GatedSAETrainer","1000","768","cuda:7","dictionary_learning.dictionary.GatedAutoEncoder","24576","30","8","openai-community/gpt2","0.00005","0","dictionary_learning.trainers.gdm.GatedSAETrainer","GatedSAETrainer","1000","21934.6796875","0.9164952039718628","0.00004069010537932627","0.9665752053260804","0.974133312702179","24.208984375","766.954833984375","47.832576751708984","0.9202543497085572","32263.42578125","3.3303747177124023","3.799464702606201","17.364582061767578","2357.88525390625","385.6487121582031","17812.669921875","0.875127911567688","0.00004069010537932627","0.9306480288505554","0.9410187005996704","11.570068359375","473.6762390136719","57.49211883544922","0.8786675930023193","30521.861328125","3.26992130279541","4.263727188110352","17.59981346130371","3412.41162109375","309.806396484375" 3 | 
"absurd-armadillo-4","finished","-","","","2024-07-13T08:39:48.000Z","51872","","768","cuda:6","dictionary_learning.dictionary.GatedAutoEncoder","24576","9.65","8","openai-community/gpt2","0.00005","0","dictionary_learning.trainers.gdm.GatedSAETrainer","GatedSAETrainer","1000","768","cuda:6","dictionary_learning.dictionary.GatedAutoEncoder","24576","14.09","8","openai-community/gpt2","0.00005","0","dictionary_learning.trainers.gdm.GatedSAETrainer","GatedSAETrainer","1000","6471.3515625","0.9502999782562256","0.00004069010537932627","0.9877426028251648","0.9842393398284912","86.309326171875","1533.554443359375","37.37459945678711","0.955512762069702","18172.8984375","3.3303747177124023","3.5023977756500244","17.364582061767578","1438.06787109375","1062.67578125","8154.59375","0.939529836177826","0.00004069010537932627","0.982342541217804","0.9703220129013062","35.64306640625","704.9683227539062","40.845314025878906","0.9403757452964784","17816.7109375","3.26992130279541","3.522951364517212","17.59981346130371","1717.0882568359375","564.1171264648438" 4 | 
"solar-smoke-2","finished","-","","","2024-07-13T08:39:44.000Z","52969","","768","cuda:3","dictionary_learning.dictionary.GatedAutoEncoder","24576","1","8","openai-community/gpt2","0.00005","0","dictionary_learning.trainers.gdm.GatedSAETrainer","GatedSAETrainer","1000","768","cuda:3","dictionary_learning.dictionary.GatedAutoEncoder","24576","1.46","8","openai-community/gpt2","0.00005","0","dictionary_learning.trainers.gdm.GatedSAETrainer","GatedSAETrainer","1000","176.92796325683594","0.999501407146454","0.00004069010537932627","0.9996459484100342","0.9998784065246582","733.0106201171875","2288.524169921875","3.913240909576416","0.996691107749939","2142.02099609375","3.3303747177124023","3.335343599319458","17.364582061767578","15.964256286621094","1948.878173828125","357.5518798828125","0.9984543323516846","0.00004069010537932627","0.9995059967041016","0.9992587566375732","805.2034912109375","2221.821044921875","6.802419662475586","0.9969862699508668","3115.2333984375","3.26992130279541","3.2770004272460938","17.59981346130371","48.07682037353515","1850.6416015625" 5 | 
"super-pond-1","finished","-","","","2024-07-13T08:39:44.000Z","51669","","768","cuda:4","dictionary_learning.dictionary.GatedAutoEncoder","24576","2.13","8","openai-community/gpt2","0.00005","0","dictionary_learning.trainers.gdm.GatedSAETrainer","GatedSAETrainer","1000","768","cuda:4","dictionary_learning.dictionary.GatedAutoEncoder","24576","3.11","8","openai-community/gpt2","0.00005","0","dictionary_learning.trainers.gdm.GatedSAETrainer","GatedSAETrainer","1000","3244.373046875","0.9685524702072144","0.00004069010537932627","0.9927898645401","0.9899781346321106","523.408447265625","4199.6708984375","30.41826057434082","0.9841817617416382","10093.599609375","3.3303747177124023","3.431563138961792","17.364582061767578","969.7067260742188","2751.1318359375","1491.67626953125","0.9849917888641356","0.00004069010537932627","0.9976735711097716","0.9923139214515686","530.9619140625","1741.81884765625","20.90458297729492","0.987014651298523","6693.18212890625","3.26992130279541","3.3032593727111816","17.59981346130371","447.7106323242187","1513.2750244140625" 6 | 
"serene-plant-2","finished","-","","","2024-07-13T08:39:44.000Z","51594","","768","cuda:5","dictionary_learning.dictionary.GatedAutoEncoder","24576","4.53","8","openai-community/gpt2","0.00005","0","dictionary_learning.trainers.gdm.GatedSAETrainer","GatedSAETrainer","1000","768","cuda:5","dictionary_learning.dictionary.GatedAutoEncoder","24576","6.62","8","openai-community/gpt2","0.00005","0","dictionary_learning.trainers.gdm.GatedSAETrainer","GatedSAETrainer","1000","3213.54931640625","0.964131236076355","0.00004069010537932627","0.9926445484161376","0.9886342883110046","308.4468994140625","2597.927734375","31.888023376464844","0.9749326705932616","12468.984375","3.3303747177124023","3.4336023330688477","17.364582061767578","1053.774658203125","1792.6346435546875","8968.025390625","0.9365889430046082","0.00004069010537932627","0.9807283878326416","0.96932053565979","95.2742919921875","1390.1123046875","41.98253631591797","0.941771388053894","17921.384765625","3.26992130279541","3.546081781387329","17.59981346130371","1815.49267578125","1141.9906005859375" -------------------------------------------------------------------------------- /results/heaviside_softmax/experts16h.csv: -------------------------------------------------------------------------------- 1 | ,experts,heaviside,l0,mse_loss,frac_recovered,loss_original,loss_reconstructed,delta_ce 2 | 7,16.0,1.0,8.0,2643.10888671875,0.9520765542984008,3.26992130279541,3.9566588401794434,0.6867375373840332 3 | 5,16.0,1.0,32.0,1720.351806640625,0.9821312427520752,3.26992130279541,3.525978803634644,0.25605750083923384 4 | 3,16.0,1.0,64.0,1384.42578125,0.9890355467796326,3.26992130279541,3.4270408153533936,0.1571195125579834 5 | 1,16.0,1.0,128.0,1037.333251953125,0.9934381246566772,3.26992130279541,3.363952159881592,0.09403085708618164 6 | -------------------------------------------------------------------------------- /results/heaviside_softmax/experts16s.csv: 
-------------------------------------------------------------------------------- 1 | ,experts,heaviside,l0,mse_loss,frac_recovered,loss_original,loss_reconstructed,delta_ce 2 | 6,16.0,0.0,8.0,2345.244384765625,0.9605113863945008,3.3303747177124023,3.884566307067871,0.5541915893554688 3 | 4,16.0,0.0,32.0,1412.8662109375,0.9869770407676696,3.3303747177124023,3.513141870498657,0.18276715278625444 4 | 2,16.0,0.0,64.0,1118.0126953125,0.9918711185455322,3.3303747177124023,3.444457054138184,0.1140823364257817 5 | 0,16.0,0.0,128.0,856.46142578125,0.9954178333282472,3.3303747177124023,3.394681453704834,0.06430673599243164 6 | -------------------------------------------------------------------------------- /results/heaviside_softmax/experts_l0_deltace.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amudide/switch_sae/1b484570c74063b8331cb658a6bbf1964048a7de/results/heaviside_softmax/experts_l0_deltace.png -------------------------------------------------------------------------------- /results/heaviside_softmax/experts_l0_lossrec.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amudide/switch_sae/1b484570c74063b8331cb658a6bbf1964048a7de/results/heaviside_softmax/experts_l0_lossrec.png -------------------------------------------------------------------------------- /results/heaviside_softmax/experts_l0_mse.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amudide/switch_sae/1b484570c74063b8331cb658a6bbf1964048a7de/results/heaviside_softmax/experts_l0_mse.png -------------------------------------------------------------------------------- /results/heaviside_softmax/primary-experts.csv: -------------------------------------------------------------------------------- 1 | 
"Name","State","Notes","User","Tags","Created","Runtime","Sweep","SwitchAutoEncoder-0.activation_dim","SwitchAutoEncoder-0.auxk_alpha","SwitchAutoEncoder-0.decay_start","SwitchAutoEncoder-0.device","SwitchAutoEncoder-0.dict_class","SwitchAutoEncoder-0.dict_size","SwitchAutoEncoder-0.experts","SwitchAutoEncoder-0.heaviside","SwitchAutoEncoder-0.k","SwitchAutoEncoder-0.layer","SwitchAutoEncoder-0.lm_name","SwitchAutoEncoder-0.seed","SwitchAutoEncoder-0.steps","SwitchAutoEncoder-0.trainer","SwitchAutoEncoder-0.wandb_name","SwitchAutoEncoder-1.activation_dim","SwitchAutoEncoder-1.auxk_alpha","SwitchAutoEncoder-1.decay_start","SwitchAutoEncoder-1.device","SwitchAutoEncoder-1.dict_class","SwitchAutoEncoder-1.dict_size","SwitchAutoEncoder-1.experts","SwitchAutoEncoder-1.heaviside","SwitchAutoEncoder-1.k","SwitchAutoEncoder-1.layer","SwitchAutoEncoder-1.lm_name","SwitchAutoEncoder-1.seed","SwitchAutoEncoder-1.steps","SwitchAutoEncoder-1.trainer","SwitchAutoEncoder-1.wandb_name","SwitchAutoEncoder-0/auxk_loss","SwitchAutoEncoder-0/cossim","SwitchAutoEncoder-0/dead_features","SwitchAutoEncoder-0/effective_l0","SwitchAutoEncoder-0/frac_alive","SwitchAutoEncoder-0/frac_recovered","SwitchAutoEncoder-0/frac_variance_explained","SwitchAutoEncoder-0/l0","SwitchAutoEncoder-0/l1_loss","SwitchAutoEncoder-0/l2_loss","SwitchAutoEncoder-0/l2_ratio","SwitchAutoEncoder-0/loss","SwitchAutoEncoder-0/loss_original","SwitchAutoEncoder-0/loss_reconstructed","SwitchAutoEncoder-0/loss_zero","SwitchAutoEncoder-0/mse_loss","SwitchAutoEncoder-1/auxk_loss","SwitchAutoEncoder-1/cossim","SwitchAutoEncoder-1/dead_features","SwitchAutoEncoder-1/effective_l0","SwitchAutoEncoder-1/frac_alive","SwitchAutoEncoder-1/frac_recovered","SwitchAutoEncoder-1/frac_variance_explained","SwitchAutoEncoder-1/l0","SwitchAutoEncoder-1/l1_loss","SwitchAutoEncoder-1/l2_loss","SwitchAutoEncoder-1/l2_ratio","SwitchAutoEncoder-1/loss","SwitchAutoEncoder-1/loss_original","SwitchAutoEncoder-1/loss_reconstructed","SwitchAutoEncod
er-1/loss_zero","SwitchAutoEncoder-1/mse_loss" 2 | "leafy-frost-5","finished","-","","","2024-07-13T06:09:56.000Z","44661","","768","0.03125","80000","cuda:4","dictionary_learning.trainers.switch.SwitchAutoEncoder","24576","16","false","128","8","openai-community/gpt2","0","100000","dictionary_learning.trainers.switch.SwitchTrainer","SwitchAutoEncoder","768","0.03125","80000","cuda:4","dictionary_learning.trainers.switch.SwitchAutoEncoder","24576","16","true","128","8","openai-community/gpt2","0","100000","dictionary_learning.trainers.switch.SwitchTrainer","SwitchAutoEncoder","0.00003293702320661396","0.970380425453186","2266","128","0.00004069010537932627","0.9954178333282472","0.9893406629562378","346.756103515625","929.089599609375","27.995044708251953","0.970288336277008","846.9383544921875","3.3303747177124023","3.394681453704834","17.364582061767578","856.46142578125","0","0.9639195203781128","0","128","0.00004069010537932627","0.9934381246566772","0.987186849117279","344.84912109375","932.871826171875","31.138887405395508","0.9640032052993774","1024.9150390625","3.26992130279541","3.363952159881592","17.59981346130371","1037.333251953125" 3 | 
"winter-moon-4","finished","-","","","2024-07-13T06:09:42.000Z","44679","","768","0.03125","80000","cuda:7","dictionary_learning.trainers.switch.SwitchAutoEncoder","24576","16","false","64","8","openai-community/gpt2","0","100000","dictionary_learning.trainers.switch.SwitchTrainer","SwitchAutoEncoder","768","0.03125","80000","cuda:7","dictionary_learning.trainers.switch.SwitchAutoEncoder","24576","16","true","64","8","openai-community/gpt2","0","100000","dictionary_learning.trainers.switch.SwitchTrainer","SwitchAutoEncoder","0.00003745846333913505","0.9611231684684752","2292","64","0.00004069010537932627","0.9918711185455322","0.9860854744911194","221.5013427734375","616.9510498046875","32.06147384643555","0.9609488248825072","1105.567626953125","3.3303747177124023","3.4444570541381836","17.364582061767578","1118.0126953125","0","0.9514711499214172","0","64","0.00004069010537932627","0.9890355467796326","0.9828994274139404","234.9410400390625","667.9144287109375","35.96377182006836","0.9516751170158386","1366.359130859375","3.26992130279541","3.4270408153533936","17.59981346130371","1384.42578125" 4 | 
"likely-aardvark-3","finished","-","","","2024-07-13T06:09:22.000Z","43728","","768","0.03125","80000","cuda:6","dictionary_learning.trainers.switch.SwitchAutoEncoder","24576","16","false","32","8","openai-community/gpt2","0","100000","dictionary_learning.trainers.switch.SwitchTrainer","SwitchAutoEncoder","768","0.03125","80000","cuda:6","dictionary_learning.trainers.switch.SwitchAutoEncoder","24576","16","true","32","8","openai-community/gpt2","0","100000","dictionary_learning.trainers.switch.SwitchTrainer","SwitchAutoEncoder","0.00004270674253348261","0.9505878686904908","3495","32","0.00004069010537932627","0.9869770407676696","0.9824159145355223","147.52978515625","473.6759033203125","36.09650421142578","0.9501758217811584","1390.8076171875","3.3303747177124023","3.5131418704986572","17.364582061767578","1412.8662109375","0","0.9392642974853516","0","32","0.00004069010537932627","0.9821312427520752","0.9787499904632568","158.9560546875","509.13079833984375","40.07820129394531","0.9391745328903198","1695.3155517578125","3.26992130279541","3.5259788036346436","17.59981346130371","1720.351806640625" 5 | 
"genial-dawn-2","finished","-","","","2024-07-13T06:09:08.000Z","46116","","768","0.03125","80000","cuda:5","dictionary_learning.trainers.switch.SwitchAutoEncoder","24576","16","false","8","8","openai-community/gpt2","0","100000","dictionary_learning.trainers.switch.SwitchTrainer","SwitchAutoEncoder","768","0.03125","80000","cuda:5","dictionary_learning.trainers.switch.SwitchAutoEncoder","24576","16","true","8","8","openai-community/gpt2","0","100000","dictionary_learning.trainers.switch.SwitchTrainer","SwitchAutoEncoder","0.00006340954860206693","0.91660076379776","14566","8","0.00004069010537932627","0.9605113863945008","0.970811665058136","115.51416015625","384.5042419433594","46.83196258544922","0.9166349172592164","2320.48779296875","3.3303747177124023","3.884566307067871","17.364582061767578","2345.244384765625","0.00006867974298074841","0.905162513256073","6606","8","0.00004069010537932627","0.9520765542984008","0.967352032661438","137.64208984375","417.4818420410156","49.77260971069336","0.905208945274353","2604.93310546875","3.26992130279541","3.9566588401794434","17.59981346130371","2643.10888671875" -------------------------------------------------------------------------------- /results/load_balance/alpha_deltace.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amudide/switch_sae/1b484570c74063b8331cb658a6bbf1964048a7de/results/load_balance/alpha_deltace.png -------------------------------------------------------------------------------- /results/load_balance/alpha_lossrec.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amudide/switch_sae/1b484570c74063b8331cb658a6bbf1964048a7de/results/load_balance/alpha_lossrec.png -------------------------------------------------------------------------------- /results/load_balance/alpha_mse.pdf: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/amudide/switch_sae/1b484570c74063b8331cb658a6bbf1964048a7de/results/load_balance/alpha_mse.pdf -------------------------------------------------------------------------------- /results/load_balance/alpha_mse.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amudide/switch_sae/1b484570c74063b8331cb658a6bbf1964048a7de/results/load_balance/alpha_mse.png -------------------------------------------------------------------------------- /results/primary/flopmatch_l0_deltace.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amudide/switch_sae/1b484570c74063b8331cb658a6bbf1964048a7de/results/primary/flopmatch_l0_deltace.png -------------------------------------------------------------------------------- /results/primary/flopmatch_l0_lossrec.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amudide/switch_sae/1b484570c74063b8331cb658a6bbf1964048a7de/results/primary/flopmatch_l0_lossrec.png -------------------------------------------------------------------------------- /results/primary/flopmatch_l0_mse.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amudide/switch_sae/1b484570c74063b8331cb658a6bbf1964048a7de/results/primary/flopmatch_l0_mse.png -------------------------------------------------------------------------------- /results/primary/flopmatch_mse_deltace.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amudide/switch_sae/1b484570c74063b8331cb658a6bbf1964048a7de/results/primary/flopmatch_mse_deltace.png -------------------------------------------------------------------------------- /results/primary/l0_deltace.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/amudide/switch_sae/1b484570c74063b8331cb658a6bbf1964048a7de/results/primary/l0_deltace.png -------------------------------------------------------------------------------- /results/primary/l0_lossrec.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amudide/switch_sae/1b484570c74063b8331cb658a6bbf1964048a7de/results/primary/l0_lossrec.png -------------------------------------------------------------------------------- /results/primary/l0_mse.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amudide/switch_sae/1b484570c74063b8331cb658a6bbf1964048a7de/results/primary/l0_mse.png -------------------------------------------------------------------------------- /results/primary/mse_deltace.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amudide/switch_sae/1b484570c74063b8331cb658a6bbf1964048a7de/results/primary/mse_deltace.png -------------------------------------------------------------------------------- /results/primary/primary-gated-clean.csv: -------------------------------------------------------------------------------- 1 | ,l0,mse_loss,frac_recovered,delta_ce 2 | 3,8.206787109375,3492.701171875,0.9255402088165284,1.067000389099121 3 | 2,12.566650390625,2559.174560546875,0.9579382538795472,0.5903034210205078 4 | 5,23.200927734375,1837.4671630859373,0.9785899519920348,0.3068037033081059 5 | 4,48.8880615234375,1253.7666015625,0.9901224970817566,0.13862347602844238 6 | 7,82.66162109375,950.4469604492188,0.99338698387146,0.09476351737976074 7 | 1,104.4046630859375,870.0953979492188,0.9947134852409364,0.07575511932373047 8 | 0,127.900146484375,791.554931640625,0.9953534007072448,0.06521177291870117 9 | 
6,159.738525390625,710.6533203125,0.9960606098175048,0.05528640747070268 10 | 9,289.6943359375,508.3369140625,0.9972550868988036,0.03933382034301802 -------------------------------------------------------------------------------- /results/primary/primary-gated.csv: -------------------------------------------------------------------------------- 1 | "Name","State","Notes","User","Tags","Created","Runtime","Sweep","GatedSAETrainer-0.activation_dim","GatedSAETrainer-0.device","GatedSAETrainer-0.dict_class","GatedSAETrainer-0.dict_size","GatedSAETrainer-0.l1_penalty","GatedSAETrainer-0.layer","GatedSAETrainer-0.lm_name","GatedSAETrainer-0.lr","GatedSAETrainer-0.seed","GatedSAETrainer-0.trainer","GatedSAETrainer-0.wandb_name","GatedSAETrainer-0.warmup_steps","GatedSAETrainer-1.activation_dim","GatedSAETrainer-1.device","GatedSAETrainer-1.dict_class","GatedSAETrainer-1.dict_size","GatedSAETrainer-1.l1_penalty","GatedSAETrainer-1.layer","GatedSAETrainer-1.lm_name","GatedSAETrainer-1.lr","GatedSAETrainer-1.seed","GatedSAETrainer-1.trainer","GatedSAETrainer-1.wandb_name","GatedSAETrainer-1.warmup_steps","GatedSAETrainer-0/aux_loss","GatedSAETrainer-0/cossim","GatedSAETrainer-0/frac_alive","GatedSAETrainer-0/frac_recovered","GatedSAETrainer-0/frac_variance_explained","GatedSAETrainer-0/l0","GatedSAETrainer-0/l1_loss","GatedSAETrainer-0/l2_loss","GatedSAETrainer-0/l2_ratio","GatedSAETrainer-0/loss","GatedSAETrainer-0/loss_original","GatedSAETrainer-0/loss_reconstructed","GatedSAETrainer-0/loss_zero","GatedSAETrainer-0/mse_loss","GatedSAETrainer-0/sparsity_loss","GatedSAETrainer-1/aux_loss","GatedSAETrainer-1/cossim","GatedSAETrainer-1/frac_alive","GatedSAETrainer-1/frac_recovered","GatedSAETrainer-1/frac_variance_explained","GatedSAETrainer-1/l0","GatedSAETrainer-1/l1_loss","GatedSAETrainer-1/l2_loss","GatedSAETrainer-1/l2_ratio","GatedSAETrainer-1/loss","GatedSAETrainer-1/loss_original","GatedSAETrainer-1/loss_reconstructed","GatedSAETrainer-1/loss_zero","GatedSAETrainer-1/ms
e_loss","GatedSAETrainer-1/sparsity_loss" 2 | "rare-firefly-6","finished","-","","","2024-07-11T18:51:52.000Z","51591","","768","cuda:4","dictionary_learning.dictionary.GatedAutoEncoder","24576","5","8","openai-community/gpt2","0.0003","0","dictionary_learning.trainers.gdm.GatedSAETrainer","GatedSAETrainer","1000","768","cuda:4","dictionary_learning.dictionary.GatedAutoEncoder","24576","6","8","openai-community/gpt2","0.0003","0","dictionary_learning.trainers.gdm.GatedSAETrainer","GatedSAETrainer","1000","1163.14306640625","0.973246932029724","0.00004069010537932627","0.9953534007072448","0.9913633465766908","127.900146484375","528.8916015625","27.75778579711914","0.9604530334472656","3927.76611328125","3.3303747177124023","3.3955864906311035","17.364582061767578","791.554931640625","394.5193481445313","1230.1353759765625","0.9700021743774414","0.00004069010537932627","0.9947134852409364","0.984986126422882","104.4046630859375","507.0380859375","29.08409881591797","0.9812511205673218","4367.89501953125","3.26992130279541","3.3456764221191406","17.59981346130371","870.0953979492188","376.71624755859375" 3 | 
"dashing-lion-5","finished","-","","","2024-07-11T05:25:59.000Z","50846","","768","cuda:7","dictionary_learning.dictionary.GatedAutoEncoder","24576","20.56","8","openai-community/gpt2","0.0003","0","dictionary_learning.trainers.gdm.GatedSAETrainer","GatedSAETrainer","1000","768","cuda:7","dictionary_learning.dictionary.GatedAutoEncoder","24576","30","8","openai-community/gpt2","0.0003","0","dictionary_learning.trainers.gdm.GatedSAETrainer","GatedSAETrainer","1000","3559.26806640625","0.9084912538528442","0.00004069010537932627","0.9579382538795472","0.9719541668891908","12.566650390625","378.3607177734375","49.76805114746094","0.895888090133667","11920.619140625","3.3303747177124023","3.92067813873291","17.364582061767578","2559.174560546875","280.8070068359375","4629.88525390625","0.8717569708824158","0.00004069010537932627","0.9255402088165284","0.9396619200706482","8.206787109375","312.36065673828125","58.12748718261719","0.8806691765785217","15425.427734375","3.26992130279541","4.336921691894531","17.59981346130371","3492.701171875","243.12893676757812" 4 | 
"graceful-totem-4","finished","-","","","2024-07-11T05:25:45.000Z","51864","","768","cuda:6","dictionary_learning.dictionary.GatedAutoEncoder","24576","9.65","8","openai-community/gpt2","0.0003","0","dictionary_learning.trainers.gdm.GatedSAETrainer","GatedSAETrainer","1000","768","cuda:6","dictionary_learning.dictionary.GatedAutoEncoder","24576","14.09","8","openai-community/gpt2","0.0003","0","dictionary_learning.trainers.gdm.GatedSAETrainer","GatedSAETrainer","1000","1701.7183837890625","0.9564532041549684","0.00004069010537932627","0.9901224970817566","0.9862537980079652","48.8880615234375","440.683349609375","34.85649871826172","0.9686530232429504","6070.19482421875","3.3303747177124023","3.4689981937408447","17.364582061767578","1253.7666015625","322.1306457519531","2614.995361328125","0.934950828552246","0.00004069010537932627","0.9785899519920348","0.968268096446991","23.200927734375","387.74371337890625","42.21244812011719","0.9180249571800232","8781.263671875","3.26992130279541","3.5767250061035156","17.59981346130371","1837.4671630859375","305.5757751464844" 5 | 
"silver-puddle-3","finished","-","","","2024-07-11T05:25:35.000Z","51579","","768","cuda:5","dictionary_learning.dictionary.GatedAutoEncoder","24576","4.53","8","openai-community/gpt2","0.0003","0","dictionary_learning.trainers.gdm.GatedSAETrainer","GatedSAETrainer","1000","768","cuda:5","dictionary_learning.dictionary.GatedAutoEncoder","24576","6.62","8","openai-community/gpt2","0.0003","0","dictionary_learning.trainers.gdm.GatedSAETrainer","GatedSAETrainer","1000","1059.000244140625","0.975863754749298","0.00004069010537932627","0.9960606098175048","0.9922483563423156","159.738525390625","578.526123046875","26.311534881591797","0.9801684617996216","3686.526611328125","3.3303747177124023","3.3856611251831055","17.364582061767578","710.6533203125","420.4597473144531","1368.194091796875","0.9670422077178956","0.00004069010537932627","0.99338698387146","0.9835944175720216","82.66162109375","464.7086181640625","30.388513565063477","0.9570202827453612","4679.21875","3.26992130279541","3.364684820175171","17.59981346130371","950.4469604492188","356.64434814453125" 6 | 
"generous-fog-2","finished","-","","","2024-07-11T05:25:20.000Z","51422","","768","cuda:4","dictionary_learning.dictionary.GatedAutoEncoder","24576","2.13","8","openai-community/gpt2","0.0003","0","dictionary_learning.trainers.gdm.GatedSAETrainer","GatedSAETrainer","1000","768","cuda:4","dictionary_learning.dictionary.GatedAutoEncoder","24576","3.11","8","openai-community/gpt2","0.0003","0","dictionary_learning.trainers.gdm.GatedSAETrainer","GatedSAETrainer","1000","603.53125","0.9911146759986876","0.00004069010537932627","0.998424232006073","0.9973034858703612","529.562744140625","1140.15185546875","16.387954711914062","0.993000864982605","2631.47314453125","3.3303747177124023","3.352489471435547","17.364582061767578","278.48040771484375","825.9996337890625","819.4901123046875","0.9827598333358764","0.00004069010537932627","0.9972550868988036","0.9914193153381348","289.6943359375","723.2354736328125","22.326417922973633","0.9679235219955444","3019.125","3.26992130279541","3.3092551231384277","17.59981346130371","508.3369140625","546.3135986328125" 7 | 
"pleasant-oath-1","finished","-","","","2024-07-11T05:21:07.000Z","52028","","768","cuda:6","dictionary_learning.dictionary.GatedAutoEncoder","24576","1","8","openai-community/gpt2","0.0003","0","dictionary_learning.trainers.gdm.GatedSAETrainer","GatedSAETrainer","1000","768","cuda:6","dictionary_learning.dictionary.GatedAutoEncoder","24576","1.46","8","openai-community/gpt2","0.0003","0","dictionary_learning.trainers.gdm.GatedSAETrainer","GatedSAETrainer","1000","245.25555419921875","0.9967520833015442","0.00004069010537932627","0.9992234110832214","0.9996508359909058","645.68798828125","1738.0128173828125","9.955805778503418","0.995638370513916","1791.7958984375","3.3303747177124023","3.341273546218872","17.364582061767578","102.31087493896484","1449.95556640625","388.83349609375","0.9960089921951294","0.00004069010537932627","0.9988710284233092","0.9988107085227966","565.810791015625","1553.6287841796875","10.871599197387695","0.9927268028259276","2247.65771484375","3.26992130279541","3.2860991954803467","17.59981346130371","119.66344451904295","1189.346923828125" -------------------------------------------------------------------------------- /results/primary/primary-relu-clean.csv: -------------------------------------------------------------------------------- 1 | ,l0,mse_loss,frac_recovered,delta_ce 2 | 5,15.46533203125,5551.58984375,0.900568962097168,1.4248356819152832 3 | 3,15.949462890625,3971.114501953125,0.9313362836837769,0.9839439392089844 4 | 9,16.97021484375,3336.4921875,0.9437114000320436,0.8066086769104004 5 | 11,21.419921875,2474.707763671875,0.9623249769210817,0.5398788452148438 6 | 13,31.735595703125,1932.8282470703125,0.974922776222229,0.35935378074645996 7 | 12,44.6102294921875,1713.539306640625,0.979756772518158,0.28409790992736816 8 | 15,53.3970947265625,1546.72998046875,0.9833840727806092,0.23810505867004395 9 | 14,75.8472900390625,1370.767333984375,0.9866704344749452,0.18707013130187988 10 | 
17,101.04296875,1230.6522216796875,0.9891743063926696,0.15513181686401367 11 | 16,150.755615234375,1069.6497802734375,0.9915485978126526,0.11860823631286666 12 | 1,213.8397216796875,946.8670654296876,0.9934131503105164,0.09438896179199219 -------------------------------------------------------------------------------- /results/primary/primary-topk-clean.csv: -------------------------------------------------------------------------------- 1 | ,l0,mse_loss,frac_recovered,delta_ce 2 | 2,8.0,2219.68701171875,0.9644190669059752,0.49935078620910645 3 | 3,16.0,1562.1041259765625,0.9841692447662354,0.2268538475036621 4 | 4,32.0,1142.533447265625,0.9910413026809692,0.12572836875915527 5 | 5,48.0,991.705810546875,0.9935886263847352,0.0918736457824707 6 | 6,64.0,888.516845703125,0.9945665597915648,0.07625389099121094 7 | 7,96.0,761.394287109375,0.9957606792449952,0.060749053955078125 8 | 0,128.0,661.279541015625,0.9967961311340332,0.04496359825134277 9 | 1,192.0,522.2657470703125,0.9979140758514404,0.029891014099121094 10 | -------------------------------------------------------------------------------- /results/primary/primary-topk.csv: -------------------------------------------------------------------------------- 1 | 
"Name","State","Notes","User","Tags","Created","Runtime","Sweep","AutoEncoderTopK-0.activation_dim","AutoEncoderTopK-0.auxk_alpha","AutoEncoderTopK-0.decay_start","AutoEncoderTopK-0.device","AutoEncoderTopK-0.dict_class","AutoEncoderTopK-0.dict_size","AutoEncoderTopK-0.k","AutoEncoderTopK-0.layer","AutoEncoderTopK-0.lm_name","AutoEncoderTopK-0.seed","AutoEncoderTopK-0.steps","AutoEncoderTopK-0.trainer","AutoEncoderTopK-0.wandb_name","AutoEncoderTopK-1.activation_dim","AutoEncoderTopK-1.auxk_alpha","AutoEncoderTopK-1.decay_start","AutoEncoderTopK-1.device","AutoEncoderTopK-1.dict_class","AutoEncoderTopK-1.dict_size","AutoEncoderTopK-1.k","AutoEncoderTopK-1.layer","AutoEncoderTopK-1.lm_name","AutoEncoderTopK-1.seed","AutoEncoderTopK-1.steps","AutoEncoderTopK-1.trainer","AutoEncoderTopK-1.wandb_name","AutoEncoderTopK-0/auxk_loss","AutoEncoderTopK-0/cossim","AutoEncoderTopK-0/dead_features","AutoEncoderTopK-0/effective_l0","AutoEncoderTopK-0/frac_alive","AutoEncoderTopK-0/frac_recovered","AutoEncoderTopK-0/frac_variance_explained","AutoEncoderTopK-0/l0","AutoEncoderTopK-0/l1_loss","AutoEncoderTopK-0/l2_loss","AutoEncoderTopK-0/l2_ratio","AutoEncoderTopK-0/loss","AutoEncoderTopK-0/loss_original","AutoEncoderTopK-0/loss_reconstructed","AutoEncoderTopK-0/loss_zero","AutoEncoderTopK-0/mse_loss","AutoEncoderTopK-1/auxk_loss","AutoEncoderTopK-1/cossim","AutoEncoderTopK-1/dead_features","AutoEncoderTopK-1/effective_l0","AutoEncoderTopK-1/frac_alive","AutoEncoderTopK-1/frac_recovered","AutoEncoderTopK-1/frac_variance_explained","AutoEncoderTopK-1/l0","AutoEncoderTopK-1/l1_loss","AutoEncoderTopK-1/l2_loss","AutoEncoderTopK-1/l2_ratio","AutoEncoderTopK-1/loss","AutoEncoderTopK-1/loss_original","AutoEncoderTopK-1/loss_reconstructed","AutoEncoderTopK-1/loss_zero","AutoEncoderTopK-1/mse_loss" 2 | 
"royal-water-4","finished","-","","","2024-07-14T00:29:11.000Z","45132","","768","0.03125","80000","cuda:7","dictionary_learning.trainers.top_k.AutoEncoderTopK","24576","128","8","openai-community/gpt2","0","100000","dictionary_learning.trainers.top_k.TrainerTopK","AutoEncoderTopK","768","0.03125","80000","cuda:7","dictionary_learning.trainers.top_k.AutoEncoderTopK","24576","192","8","openai-community/gpt2","0","100000","dictionary_learning.trainers.top_k.TrainerTopK","AutoEncoderTopK","2.7579562811297365e-7","0.9772376418113708","3","128","0.00004069010537932627","0.9967961311340332","0.9927451014518738","500.12890625","850.615234375","25.00290298461914","0.976984202861786","660.1929931640625","3.3303747177124023","3.375338315963745","17.364582061767578","661.279541015625","0.0000015463250520042493","0.9819841384887696","18","192","0.00004069010537932627","0.9979140758514404","0.990972638130188","654.5338134765625","987.9059448242188","22.222896575927734","0.9815428256988524","517.9833374023438","3.26992130279541","3.2998123168945312","17.59981346130371","522.2657470703125" 3 | 
"drawn-voice-3","finished","-","","","2024-07-14T00:28:27.000Z","50754","","768","0.03125","80000","cuda:4","dictionary_learning.trainers.top_k.AutoEncoderTopK","24576","8","8","openai-community/gpt2","0","100000","dictionary_learning.trainers.top_k.TrainerTopK","AutoEncoderTopK","768","0.03125","80000","cuda:4","dictionary_learning.trainers.top_k.AutoEncoderTopK","24576","16","8","openai-community/gpt2","0","100000","dictionary_learning.trainers.top_k.TrainerTopK","AutoEncoderTopK","0.00014197103155311197","0.9215346574783324","17839","8","0.00004069010537932627","0.9644190669059752","0.9756476283073424","1591.396484375","1681.16650390625","45.909427642822266","0.92127525806427","2218.64794921875","3.3303747177124023","3.829725503921509","17.364582061767578","2219.68701171875","0.00005811716982861981","0.9451953768730164","9701","16","0.00004069010537932627","0.9841692447662354","0.972999393939972","317.203857421875","534.530029296875","38.529930114746094","0.9450527429580688","1550.5992431640625","3.26992130279541","3.4967751502990723","17.59981346130371","1562.1041259765625" 4 | 
"still-river-1","finished","-","","","2024-07-14T00:28:26.000Z","45766","","768","0.03125","80000","cuda:5","dictionary_learning.trainers.top_k.AutoEncoderTopK","24576","32","8","openai-community/gpt2","0","100000","dictionary_learning.trainers.top_k.TrainerTopK","AutoEncoderTopK","768","0.03125","80000","cuda:5","dictionary_learning.trainers.top_k.AutoEncoderTopK","24576","48","8","openai-community/gpt2","0","100000","dictionary_learning.trainers.top_k.TrainerTopK","AutoEncoderTopK","0.00004375946446089074","0.9603254795074464","1118","32","0.00004069010537932627","0.9910413026809692","0.9874651432037354","201.1285400390625","508.3766479492187","32.90629577636719","0.9599528908729552","1139.227294921875","3.3303747177124023","3.4561030864715576","17.364582061767578","1142.533447265625","0.0000035865712106897263","0.9654753804206848","32","48","0.00004069010537932627","0.9935886263847352","0.9828585386276244","245.3511962890625","538.3258666992188","30.67569351196289","0.9650439023971558","987.57861328125","3.26992130279541","3.361794948577881","17.59981346130371","991.705810546875" 5 | 
"vivid-bee-2","finished","-","","","2024-07-14T00:28:26.000Z","44044","","768","0.03125","80000","cuda:6","dictionary_learning.trainers.top_k.AutoEncoderTopK","24576","64","8","openai-community/gpt2","0","100000","dictionary_learning.trainers.top_k.TrainerTopK","AutoEncoderTopK","768","0.03125","80000","cuda:6","dictionary_learning.trainers.top_k.AutoEncoderTopK","24576","96","8","openai-community/gpt2","0","100000","dictionary_learning.trainers.top_k.TrainerTopK","AutoEncoderTopK","8.166377369889233e-7","0.9692695140838624","8","64","0.00004069010537932627","0.9945665597915648","0.9902520775794984","302.2916259765625","629.1004638671875","29.00596618652344","0.968896210193634","887.1822509765625","3.3303747177124023","3.4066286087036133","17.364582061767578","888.516845703125","5.727280267819879e-7","0.9736075401306152","6","96","0.00004069010537932627","0.9957606792449952","0.9868393540382384","401.2947998046875","703.3333740234375","26.865930557250977","0.9731273651123048","756.0817260742188","3.26992130279541","3.3306703567504883","17.59981346130371","761.394287109375" -------------------------------------------------------------------------------- /results/primary/switch_sae_l0_lr.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amudide/switch_sae/1b484570c74063b8331cb658a6bbf1964048a7de/results/primary/switch_sae_l0_lr.pdf -------------------------------------------------------------------------------- /results/primary/switch_sae_l0_mse.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amudide/switch_sae/1b484570c74063b8331cb658a6bbf1964048a7de/results/primary/switch_sae_l0_mse.pdf -------------------------------------------------------------------------------- /results/primary/switch_sae_pareto.pdf: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/amudide/switch_sae/1b484570c74063b8331cb658a6bbf1964048a7de/results/primary/switch_sae_pareto.pdf -------------------------------------------------------------------------------- /results/primary/switch_sae_pareto_flop.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amudide/switch_sae/1b484570c74063b8331cb658a6bbf1964048a7de/results/primary/switch_sae_pareto_flop.pdf -------------------------------------------------------------------------------- /results/primary/switch_sae_pareto_width.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amudide/switch_sae/1b484570c74063b8331cb658a6bbf1964048a7de/results/primary/switch_sae_pareto_width.pdf -------------------------------------------------------------------------------- /save_activations.py: -------------------------------------------------------------------------------- 1 | #%% 2 | 3 | from nnsight import LanguageModel 4 | from dictionary_learning.utils import hf_dataset_to_generator 5 | from config import lm, activation_dim, layer, hf, n_ctxs 6 | import torch as t 7 | import einops 8 | from tqdm import tqdm 9 | import os 10 | 11 | t.set_grad_enabled(False) 12 | 13 | # %% 14 | 15 | 16 | device = f'cuda:1' 17 | model = LanguageModel(lm, dispatch=True, device_map=device) 18 | submodule = model.transformer.h[layer] 19 | data = hf_dataset_to_generator(hf) 20 | 21 | # %% 22 | batch_size = 256 23 | num_batches = 128 24 | ctx_len = 128 25 | 26 | total_tokens = batch_size * num_batches * ctx_len 27 | total_memory = total_tokens * activation_dim * 4  # bytes; upper bound that assumes every context fills ctx_len and 4 bytes per value (e.g. float32) - TODO confirm dtype 28 | print(f"Total contexts: {batch_size * num_batches / 1e3:.2f}K") 29 | print(f"Total tokens: {total_tokens / 1e6:.2f}M") 30 | print(f"Total memory: {total_memory / 1e9:.2f}GB") 31 | 32 | # %% 33 | 34 | # These 
functions copied from buffer.py 35 | 36 | def text_batch():  # next batch_size raw texts from the dataset generator 37 | return [ 38 | next(data) for _
in range(batch_size) 39 | ] 40 | 41 | def tokenized_batch():  # tokenize one text batch, truncated to ctx_len and padded to the longest text in that batch 42 | texts = text_batch() 43 | return model.tokenizer( 44 | texts, 45 | return_tensors='pt', 46 | max_length=ctx_len, 47 | padding=True, 48 | truncation=True 49 | ) 50 | 51 | def get_activations(input):  # submodule output for non-padding positions only (attention_mask != 0), flattened across the batch 52 | with t.no_grad(): 53 | with model.trace(input): 54 | hidden_states = submodule.output.save() 55 | hidden_states = hidden_states.value 56 | if isinstance(hidden_states, tuple): 57 | hidden_states = hidden_states[0] 58 | hidden_states = hidden_states[input['attention_mask'] != 0] 59 | return hidden_states 60 | 61 | 62 | 63 | # %% 64 | 65 | all_activations = [] 66 | all_tokens = [] 67 | 68 | for _ in tqdm(range(num_batches)): 69 | batch = tokenized_batch() 70 | all_tokens.append(batch['input_ids'].cpu()) 71 | activations = get_activations(batch) 72 | activations = einops.rearrange(activations, "(b c) d -> b c d", b=batch_size)  # NOTE(review): get_activations drops padded positions, so this reshape is only valid when no row was padded; otherwise it raises or silently shifts tokens across rows - confirm 73 | all_activations.append(activations.cpu()) 74 | 75 | # %% 76 | 77 | concatenated_activations = t.cat(all_activations) 78 | concatenated_tokens = t.cat(all_tokens)  # NOTE(review): each batch is padded only to its own longest text, so input_ids widths can differ between batches and this cat can fail 79 | print(concatenated_activations.shape, concatenated_tokens.shape) 80 | 81 | # %% 82 | 83 | # save activations 84 | os.makedirs('data', exist_ok=True) 85 | t.save(concatenated_activations, f'data/gpt2_activations_layer{layer}.pt') 86 | t.save(concatenated_tokens, f'data/gpt2_tokens.pt')  # NOTE(review): 'gpt2' is hard-coded in both filenames even though the model comes from config.lm 
-------------------------------------------------------------------------------- /scaling_laws/attempt0/train-dense-topk.py: -------------------------------------------------------------------------------- 1 | """ 2 | Trains dense TopK SAEs of varying scale.
3 | """ 4 | 5 | import os 6 | import sys 7 | sys.path.append( # the switch_sae directory 8 | os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 9 | ) 10 | os.environ['HF_HOME'] = '/om/user/ericjm/.cache/huggingface' 11 | 12 | 13 | from nnsight import LanguageModel 14 | import torch as t 15 | from dictionary_learning import ActivationBuffer 16 | from dictionary_learning.training import trainSAE 17 | from dictionary_learning.utils import hf_dataset_to_generator, cfg_filename 18 | from dictionary_learning.trainers.top_k import AutoEncoderTopK, TrainerTopK 19 | from dictionary_learning.evaluation import evaluate 20 | import wandb 21 | 22 | 23 | lm = 'openai-community/gpt2' 24 | activation_dim = 768 25 | layer = 8 26 | hf = 'Skylion007/openwebtext' 27 | steps = 100_000 # test run! 28 | n_ctxs = int(1e4) 29 | 30 | dict_sizes = [2048, 4096, 8192, 16384, 32768, 65536, 131072] 31 | k = 32 32 | 33 | save_dir = "topk_dense" 34 | 35 | device = f'cuda:0' 36 | model = LanguageModel(lm, dispatch=True, device_map=device) 37 | submodule = model.transformer.h[layer] 38 | data = hf_dataset_to_generator(hf) 39 | buffer = ActivationBuffer(data, model, submodule, d_submodule=activation_dim, n_ctxs=n_ctxs, device=device) 40 | 41 | 42 | base_trainer_config = { 43 | 'trainer' : TrainerTopK, 44 | 'dict_class' : AutoEncoderTopK, 45 | 'activation_dim' : activation_dim, 46 | 'k' : k, 47 | 'auxk_alpha' : 1/32, 48 | 'decay_start' : int(steps * 0.8), 49 | 'steps' : steps, 50 | 'seed' : 0, 51 | 'device' : device, 52 | 'layer' : layer, 53 | 'lm_name' : lm, 54 | 'wandb_name' : 'AutoEncoderTopK' 55 | } 56 | 57 | trainer_configs = [(base_trainer_config | {'dict_size': ds}) for ds in dict_sizes]  # one config per dictionary size, all passed to a single trainSAE call below 58 | 59 | wandb.init(entity="ericjmichaud_", project="switch_saes_scaling_laws_attempt_0", config={f"{trainer_config['wandb_name']}-{i}" : trainer_config for i, trainer_config in enumerate(trainer_configs)}) 60 | 61 | trainSAE(buffer, 
trainer_configs=trainer_configs, save_dir=save_dir,
log_steps=1, steps=steps) 62 | 63 | print("Training finished. Evaluating SAE...", flush=True) 64 | for i, trainer_config in enumerate(trainer_configs): 65 | ae = AutoEncoderTopK.from_pretrained( 66 | os.path.join(save_dir, f'{cfg_filename(trainer_config)}/ae.pt'), 67 | k=trainer_config['k'], 68 | device=device 69 | ) 70 | metrics = evaluate(ae, buffer, device=device) 71 | log = {} 72 | log.update({f'{trainer_config["wandb_name"]}-{i}/{k}' : v for k, v in metrics.items()})  # `k` here is the comprehension-local metric name, not the TopK k = 32 set above 73 | wandb.log(log, step=steps+1) 74 | wandb.finish() 75 | -------------------------------------------------------------------------------- /scaling_laws/attempt0/train-dense-topk.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --job-name=dtk0 3 | #SBATCH --partition=tegmark 4 | #SBATCH --ntasks=1 5 | #SBATCH --mem=16GB 6 | #SBATCH --gres=gpu:a100:1 7 | #SBATCH --time=2-00:00:00 8 | #SBATCH --output=/om2/user/ericjm/switch_sae/scaling_laws/attempt0/slurm-%j.out 9 | 10 | python /om2/user/ericjm/switch_sae/scaling_laws/attempt0/train-dense-topk.py 11 | -------------------------------------------------------------------------------- /scaling_laws/attempt0/train-switch-topk.py: -------------------------------------------------------------------------------- 1 | """ 2 | Trains Switch SAEs of varying scale.
3 | """ 4 | 5 | import os 6 | import sys 7 | sys.path.append( # the switch_sae directory 8 | os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 9 | ) 10 | os.environ['HF_HOME'] = '/om/user/ericjm/.cache/huggingface' 11 | 12 | 13 | from nnsight import LanguageModel 14 | import torch as t 15 | from dictionary_learning import ActivationBuffer 16 | from dictionary_learning.training import trainSAE 17 | from dictionary_learning.utils import hf_dataset_to_generator, cfg_filename, str2bool 18 | from dictionary_learning.trainers.switch import SwitchAutoEncoder, SwitchTrainer 19 | from dictionary_learning.evaluation import evaluate 20 | import wandb 21 | import argparse 22 | # import itertools 23 | # from config import lm, activation_dim, layer, hf, steps, n_ctxs 24 | 25 | parser = argparse.ArgumentParser() 26 | parser.add_argument("--num_experts", type=int, default=8, required=True)  # NOTE(review): default=8 never takes effect because required=True 27 | parser.add_argument("--lb_alpha", type=float, default=3.0) 28 | parser.add_argument("--heaviside", type=str2bool, default=False) 29 | args = parser.parse_args() 30 | 31 | lm = 'openai-community/gpt2' 32 | activation_dim = 768 33 | layer = 8 34 | hf = 'Skylion007/openwebtext' 35 | steps = 100_000 36 | n_ctxs = int(1e4) 37 | 38 | dict_sizes = [2048, 4096, 8192, 16384, 32768, 65536, 131072] 39 | k = 32 40 | 41 | save_dir = f"topk_switch{args.num_experts}" 42 | 43 | device = f'cuda:0' 44 | model = LanguageModel(lm, dispatch=True, device_map=device) 45 | submodule = model.transformer.h[layer] 46 | data = hf_dataset_to_generator(hf) 47 | buffer = ActivationBuffer(data, model, submodule, d_submodule=activation_dim, n_ctxs=n_ctxs, device=device) 48 | 49 | base_trainer_config = { 50 | 'trainer' : SwitchTrainer, 51 | 'dict_class' : SwitchAutoEncoder, 52 | 'activation_dim' : activation_dim, 53 | 'k': k, 54 | 'experts' : args.num_experts, 55 | 'lb_alpha' : args.lb_alpha, 56 | 'heaviside' : args.heaviside, 57 | 'auxk_alpha' : 1/32, 58 | 
'decay_start' : int(steps * 0.8), 59 |
'steps' : steps, 60 | 'seed' : 0, 61 | 'device' : device, 62 | 'layer' : layer, 63 | 'lm_name' : lm, 64 | 'wandb_name' : 'SwitchAutoEncoder' 65 | } 66 | 67 | trainer_configs = [(base_trainer_config | {'dict_size': ds}) for ds in dict_sizes] 68 | # {'k': combo[0], 'experts': combo[1], 'heaviside': combo[2], 'lb_alpha': combo[3]}) for combo in itertools.product(args.ks, args.num_experts, args.heavisides, args.lb_alphas)] 69 | 70 | wandb.init(entity="ericjmichaud_", project="switch_saes_scaling_laws_attempt_0", config={f"{trainer_config['wandb_name']}-{i}" : trainer_config for i, trainer_config in enumerate(trainer_configs)}) 71 | # wandb.init(entity="amudide", project="Switch (LB)", config={f'{trainer_config["wandb_name"]}-{i}' : trainer_config for i, trainer_config in enumerate(trainer_configs)}) 72 | 73 | trainSAE(buffer, trainer_configs=trainer_configs, save_dir=save_dir, log_steps=1, steps=steps) 74 | 75 | print("Training finished. Evaluating SAE...", flush=True) 76 | for i, trainer_config in enumerate(trainer_configs): 77 | ae = SwitchAutoEncoder.from_pretrained( 78 | os.path.join(save_dir, f'{cfg_filename(trainer_config)}/ae.pt'), 79 | k = trainer_config['k'], 80 | experts = trainer_config['experts'], 81 | heaviside = trainer_config['heaviside'], 82 | device=device 83 | ) 84 | metrics = evaluate(ae, buffer, device=device) 85 | log = {} 86 | log.update({f'{trainer_config["wandb_name"]}-{i}/{k}' : v for k, v in metrics.items()})  # `k` here is the comprehension-local metric name, not the TopK k = 32 set above 87 | wandb.log(log, step=steps+1) 88 | wandb.finish() 89 | -------------------------------------------------------------------------------- /scaling_laws/attempt0/train-switch-topk.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --job-name=dtk0 3 | #SBATCH --partition=tegmark 4 | #SBATCH --ntasks=1 5 | #SBATCH --mem=16GB 6 | #SBATCH --gres=gpu:a100:1 7 | #SBATCH --time=2-00:00:00 8 | #SBATCH 
--output=/om2/user/ericjm/switch_sae/scaling_laws/attempt0/slurm-%A_%a.out 9 |
#SBATCH --array=8,64 10 | 11 | python /om2/user/ericjm/switch_sae/scaling_laws/attempt0/train-switch-topk.py --num_experts $SLURM_ARRAY_TASK_ID 12 | -------------------------------------------------------------------------------- /table/create_table_script.py: -------------------------------------------------------------------------------- 1 | # %% 2 | 3 | import itertools 4 | 5 | # Generate all combinations of parameters 6 | layers = [2, 4, 6, 8, 10]  # NOTE(review): the checked-in table/run_parallel.sh uses layers 2, 4, 8, 10, 12 instead, so it was not generated from this list - reconcile before rerunning 7 | sae_types = ['switch', 'topk'] 8 | types = ['resid', 'attn', 'mlp'] 9 | devices = [f'cuda:{i}' for i in range(8)]  # NOTE(review): unused; devices are assigned round-robin inline via len(gpt2_commands) % 8 10 | 11 | # First set of commands for GPT-2 12 | gpt2_commands = [] 13 | for layer, sae_type, type_ in itertools.product(layers, sae_types, types): 14 | cmd = f"python3 table/train_switch_table.py --device cuda:{len(gpt2_commands) % 8} --layer {layer} --lm openai-community/gpt2 --ks 64 --activation_dim 768 --dict_ratio 32 --num_experts 8 --type {type_} --sae_type {sae_type} --steps 20000 &" 15 | gpt2_commands.append(cmd) 16 | 17 | # Second set of commands for Gemma-2B 18 | gemma_commands = [] 19 | for i, sae_type in enumerate(sae_types): 20 | cmd = f"python3 table/train_switch_table.py --device cuda:{i} --layer 12 --lm google/gemma-2b --ks 64 --activation_dim 2048 --dict_ratio 32 --num_experts 8 --type resid --sae_type {sae_type} --steps 2000000 &" 21 | gemma_commands.append(cmd) 22 | 23 | # Write commands to a bash script 24 | with open('run_parallel.sh', 'w') as f:  # path is relative to the current working directory 25 | f.write('#!/bin/bash\n\n') 26 | 27 | # Write GPT-2 commands in batches of 8 28 | f.write('# GPT-2 training commands\n') 29 | for i, cmd in enumerate(gpt2_commands): 30 | f.write(f'{cmd}\n') 31 | if (i + 1) % 8 == 0: 32 | f.write('\nwait\n\n') # Wait after every 8 commands 33 | 34 | # If there are 
remaining commands that don't make a full batch of 8 35 | if len(gpt2_commands) % 8: 36 | f.write('\nwait\n\n') 37 | 38 | # Write Gemma commands 39 | f.write('# Gemma-2B training commands\n') 40 | for i, cmd in enumerate(gemma_commands): 41 |
f.write(f'{cmd}\n') 42 | 43 | f.write('\nwait\n') # Wait for all commands to finish 44 | 45 | # Make the script executable 46 | import os 47 | os.chmod('run_parallel.sh', 0o755) 48 | 49 | print(f"Created run_parallel.sh with {len(gpt2_commands)} GPT-2 commands and {len(gemma_commands)} Gemma commands") 50 | 51 | -------------------------------------------------------------------------------- /table/eval_table.py: -------------------------------------------------------------------------------- 1 | # %% 2 | import os 3 | import sys 4 | import pandas as pd 5 | 6 | # Add the parent directory to the sys path 7 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 8 | 9 | 10 | import itertools 11 | import os 12 | from dictionary_learning.trainers.top_k import AutoEncoderTopK, TrainerTopK 13 | from nnsight import LanguageModel 14 | import torch as t 15 | from dictionary_learning import ActivationBuffer 16 | from dictionary_learning.training import trainSAE 17 | from dictionary_learning.utils import hf_dataset_to_generator, cfg_filename, str2bool 18 | from dictionary_learning.trainers.switch import SwitchAutoEncoder, SwitchTrainer 19 | from modified_eval import evaluate 20 | import wandb 21 | import argparse 22 | import itertools 23 | from config import hf 24 | 25 | base_path = "/home/jengels/switch_sae/other_dictionaries" 26 | 27 | sae_types = ['switch', 'topk'] 28 | 29 | # %% 30 | 31 | def get_model_metrics(model_name, layers_to_eval, layer_types, sae_types, info_to_filename, d_submodule, device="cuda:0"):  # returns {(layer, layer_type): [metrics dict for each sae_type, in sae_types order]} 32 | model = LanguageModel(model_name, dispatch=True, device_map=device) 33 | info_to_metrics = {} 34 | 35 | for layer in layers_to_eval: 36 | for layer_type in layer_types: 37 | values = [] 38 | 39 | if model_name == "google/gemma-2b": 40 | submodule = model.model.layers[layer] 41 | else: 42 | submodule = model.transformer.h[layer] 43 | 44 | if layer_type == "resid": 45 | submodule = 
submodule 46 | elif layer_type == "mlp": 47 | submodule =
submodule.mlp 48 | elif layer_type == "attn": 49 | submodule = submodule.attn 50 | 51 | for sae_type in sae_types: 52 | 53 | data = hf_dataset_to_generator(hf, split="train") 54 | buffer = ActivationBuffer( 55 | data, 56 | model, 57 | submodule, 58 | d_submodule=d_submodule, 59 | n_ctxs=1e2, 60 | device="cpu", 61 | out_batch_size=8192, 62 | refresh_batch_size=1 if model_name == "google/gemma-2b" else 128, 63 | ) 64 | 65 | filename = info_to_filename[(layer, sae_type, layer_type)]  # NOTE(review): raises KeyError for any combination the caller skipped because its directory was missing 66 | 67 | if model_name == "google/gemma-2b": 68 | model_path = f"{filename}/checkpoints/ae_{gemma_checkpoint_id}.pt"  # gemma_checkpoint_id is a module-level global assigned later in this file, before the Gemma call 69 | else: 70 | model_path = f"{filename}/ae.pt" 71 | 72 | if sae_type == "switch": 73 | ae = SwitchAutoEncoder.from_pretrained( 74 | model_path, 75 | k=64, 76 | experts=8, 77 | heaviside=False, 78 | device=device, 79 | ) 80 | metrics = evaluate(ae, 81 | buffer, 82 | device=device, 83 | batch_size=1 if model_name == "google/gemma-2b" else 128, 84 | num_batches=32 if model_name == "google/gemma-2b" else 2) 85 | else: 86 | ae = AutoEncoderTopK.from_pretrained( 87 | model_path, 88 | k=64, 89 | device=device, 90 | ) 91 | metrics = evaluate(ae, 92 | buffer, 93 | device=device, 94 | batch_size=1 if model_name == "google/gemma-2b" else 128, 95 | num_batches=32 if model_name == "google/gemma-2b" else 2) 96 | 97 | values.append(metrics) 98 | print(metrics) 99 | 100 | key = (layer, layer_type) 101 | info_to_metrics[key] = values 102 | 103 | del model 104 | return info_to_metrics 105 | 106 | # %% 107 | 108 | 109 | # Generate all combinations of parameters 110 | layers = [2, 4, 8, 10] 111 | types = ['resid', 'attn', 'mlp'] 112 | devices = [f'cuda:{i}' for i in range(8)] 113 | 114 | # First set of commands for GPT-2 115 | gpt_info_to_filename = {} 116 | for layer, sae_type, layer_type in itertools.product(layers, 
sae_types, types): 117 | 118 | device_num = len(gpt_info_to_filename) % 8  # NOTE(review): counts only entries found so far, so one missing directory shifts device_num for every later combination; the training commands number devices by position in the full product 119 | if sae_type == 'switch': 120 | filename =
f"dict_class:.SwitchAutoEncoder'>_activation_dim:768_dict_size:24576_auxk_alpha:0.03125_decay_start:16000_steps:20000_seed:0_device:cuda:{device_num}_layer:{layer}_lm_name:openai-communitygpt2_wandb_name:SwitchAutoEncoder_k:64_experts:8_heaviside:False_lb_alpha:3" 121 | else: 122 | filename = f"dict_class:_k.AutoEncoderTopK'>_activation_dim:768_dict_size:24576_auxk_alpha:0.03125_decay_start:16000_steps:20000_seed:0_device:cuda:{device_num}_layer:{layer}_lm_name:openai-communitygpt2_wandb_name:AutoEncoderTopK_k:64" 123 | 124 | filename = os.path.join(base_path, filename) 125 | 126 | 127 | if not os.path.exists(filename): 128 | print(layer, sae_type, layer_type) 129 | continue 130 | 131 | gpt_info_to_filename[(layer, sae_type,layer_type)] = filename 132 | 133 | 134 | # Get GPT-2 metrics 135 | gpt_info_to_metrics = get_model_metrics( 136 | model_name="openai-community/gpt2", 137 | layers_to_eval=layers, 138 | layer_types=types, 139 | sae_types=sae_types, 140 | info_to_filename=gpt_info_to_filename, 141 | d_submodule=768, 142 | device="cuda:1" 143 | ) 144 | 145 | 146 | # %% 147 | # Get Gemma metrics 148 | gemma_info_to_metrics = {} 149 | gemma_checkpoint_id = 2000 150 | 151 | # First create filename mapping for Gemma 152 | gemma_info_to_filename = {} 153 | for i, sae_type in enumerate(sae_types): 154 | layer = 12 155 | device = f"cuda:{i}" 156 | if sae_type == "switch": 157 | filename = f"{base_path}/dict_class:.SwitchAutoEncoder'>_activation_dim:2048_dict_size:65536_auxk_alpha:0.03125_decay_start:1600000_steps:2000000_seed:0_device:cuda:{i}_layer:12_lm_name:googlegemma-2b_wandb_name:SwitchAutoEncoder_k:64_experts:8_heaviside:False_lb_alpha:3" 158 | else: 159 | filename = 
f"{base_path}/dict_class:_k.AutoEncoderTopK'>_activation_dim:2048_dict_size:65536_auxk_alpha:0.03125_decay_start:1600000_steps:2000000_seed:0_device:cuda:{i}_layer:12_lm_name:googlegemma-2b_wandb_name:AutoEncoderTopK_k:64" 160 | 161 | gemma_info_to_filename[(layer, sae_type, "resid")] = filename 162 |
163 | gemma_info_to_metrics = get_model_metrics( 164 | model_name="google/gemma-2b", 165 | layers_to_eval=[12], 166 | layer_types=["resid"], 167 | sae_types=sae_types, 168 | info_to_filename=gemma_info_to_filename, 169 | d_submodule=2048, 170 | device="cuda:1" 171 | ) 172 | 173 | 174 | 175 | 176 | 177 | # %% 178 | 179 | 180 | # %% 181 | rows = [] 182 | # Add GPT-2 results 183 | for (layer, layer_type), values in gpt_info_to_metrics.items(): 184 | row = { 185 | 'Model': 'GPT-2', 186 | 'Layer': layer, 187 | 'Type': layer_type, 188 | 'TopK FVE': f"{values[1]['frac_variance_explained']:.3f}", 189 | 'Switch FVE': f"{values[0]['frac_variance_explained']:.3f}", 190 | 'TopK FR': f"{values[1]['frac_recovered']:.3f}", 191 | 'Switch FR': f"{values[0]['frac_recovered']:.3f}" 192 | } 193 | rows.append(row) 194 | 195 | # Add Gemma results 196 | for (layer, layer_type), values in gemma_info_to_metrics.items(): 197 | row = { 198 | 'Model': 'Gemma', 199 | 'Layer': layer, 200 | 'Type': layer_type, 201 | 'TopK FVE': f"{values[1]['frac_variance_explained']:.3f}", 202 | 'Switch FVE': f"{values[0]['frac_variance_explained']:.3f}", 203 | 'TopK FR': f"{values[1]['frac_recovered']:.3f}", 204 | 'Switch FR': f"{values[0]['frac_recovered']:.3f}" 205 | } 206 | rows.append(row) 207 | 208 | df = pd.DataFrame(rows) 209 | print("\nResults Table:") 210 | print(df.to_markdown(index=False)) 211 | 212 | # %% 213 | 214 | df.to_csv("results.csv", index=False) 215 | 216 | # %% 217 | 218 | # Convert to LaTeX table 219 | latex_table = """\\begin{table}[h] 220 | \\centering 221 | \\begin{tabular}{lcccccccc} 222 | \\toprule 223 | Model & Layer & Type & TopK FVE & Switch FVE & TopK FR & Switch FR \\\\ 224 | \\midrule""" 225 | 226 | for _, row in df.iterrows(): 227 | latex_table += f"\n{row['Model']} & {row['Layer']} & {row['Type']} & {row['TopK FVE']} & {row['Switch FVE']} & {row['TopK FR']} & {row['Switch FR']} \\\\" 228 | 229 | latex_table += """ 230 | \\bottomrule 231 | \\end{tabular} 232 | 
\\caption{Comparison of TopK and Switch autoencoders across different models, layers and component types. FVE = Fraction of Variance Explained, FR = Fraction Recovered.} 233 | \\label{tab:model_comparison} 234 | \\end{table}""" 235 | 236 | print("\nLaTeX Table:") 237 | print(latex_table) 238 | -------------------------------------------------------------------------------- /table/modified_eval.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utilities for evaluating dictionaries on a model and dataset. 3 | """ 4 | 5 | import os 6 | import sys 7 | 8 | # Add the parent directory to the sys path 9 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 10 | 11 | 12 | import torch as t 13 | from tqdm import tqdm 14 | from dictionary_learning.buffer import ActivationBuffer, NNsightActivationBuffer 15 | from nnsight import LanguageModel 16 | from dictionary_learning.config import DEBUG 17 | 18 | 19 | def loss_recovered( 20 | text, # a batch of text 21 | model: LanguageModel, # an nnsight LanguageModel 22 | submodule, # submodules of model 23 | dictionary, # dictionaries for submodules 24 | max_len=None, # max context length for loss recovered 25 | normalize_batch=False, # normalize batch before passing through dictionary 26 | io="out", # can be 'in', 'out', or 'in_and_out' 27 | tracer_args = {'use_cache': False, 'output_attentions': False}, # minimize cache during model trace. 28 | ): 29 | """ 30 | How much of the model's loss is recovered by replacing the component output 31 | with the reconstruction by the autoencoder? 
32 | """ 33 | 34 | if max_len is None: 35 | invoker_args = {} 36 | else: 37 | invoker_args = {"truncation": True, "max_length": max_len } 38 | 39 | # unmodified logits 40 | with model.trace(text, invoker_args=invoker_args): 41 | logits_original = model.output.save() 42 | logits_original = logits_original.value 43 | 44 | # logits when replacing component activations with reconstruction by autoencoder 45 | with model.trace(text, **tracer_args, invoker_args=invoker_args): 46 | if io == 'in': 47 | x = submodule.input[0] 48 | if type(submodule.input.shape) == tuple: x = x[0] 49 | if normalize_batch: 50 | scale = (dictionary.activation_dim ** 0.5) / x.norm(dim=-1).mean() 51 | x = x * scale 52 | elif io == 'out': 53 | x = submodule.output 54 | if type(submodule.output.shape) == tuple: x = x[0] 55 | if normalize_batch: 56 | scale = (dictionary.activation_dim ** 0.5) / x.norm(dim=-1).mean() 57 | x = x * scale 58 | elif io == 'in_and_out': 59 | x = submodule.input[0] 60 | if type(submodule.input.shape) == tuple: x = x[0] 61 | print(f'x.shape: {x.shape}') 62 | if normalize_batch: 63 | scale = (dictionary.activation_dim ** 0.5) / x.norm(dim=-1).mean() 64 | x = x * scale 65 | else: 66 | raise ValueError(f"Invalid value for io: {io}") 67 | x = x.save() 68 | 69 | # pull this out so dictionary can be written without FakeTensor (top_k needs this) 70 | x_hat = dictionary(x.view(-1, x.shape[-1])).view(x.shape) 71 | 72 | # intervene with `x_hat` 73 | with model.trace(text, **tracer_args, invoker_args=invoker_args): 74 | if io == 'in': 75 | x = submodule.input[0] 76 | if normalize_batch: 77 | scale = (dictionary.activation_dim ** 0.5) / x.norm(dim=-1).mean() 78 | x_hat = x_hat / scale 79 | if type(submodule.input.shape) == tuple: 80 | submodule.input[0][:] = x_hat 81 | else: 82 | submodule.input = x_hat 83 | elif io == 'out': 84 | x = submodule.output 85 | if normalize_batch: 86 | scale = (dictionary.activation_dim ** 0.5) / x.norm(dim=-1).mean() 87 | x_hat = x_hat / scale 88 | if 
type(submodule.output.shape) == tuple: 89 | submodule.output = (x_hat,) 90 | else: 91 | submodule.output = x_hat 92 | elif io == 'in_and_out': 93 | x = submodule.input[0] 94 | if normalize_batch: 95 | scale = (dictionary.activation_dim ** 0.5) / x.norm(dim=-1).mean() 96 | x_hat = x_hat / scale 97 | submodule.output = x_hat 98 | else: 99 | raise ValueError(f"Invalid value for io: {io}") 100 | 101 | logits_reconstructed = model.output.save() 102 | logits_reconstructed = logits_reconstructed.value 103 | 104 | # logits when replacing component activations with zeros 105 | with model.trace(text, **tracer_args, invoker_args=invoker_args): 106 | if io == 'in': 107 | x = submodule.input[0] 108 | if type(submodule.input.shape) == tuple: 109 | submodule.input[0][:] = t.zeros_like(x[0]) 110 | else: 111 | submodule.input = t.zeros_like(x) 112 | elif io in ['out', 'in_and_out']: 113 | x = submodule.output 114 | if type(submodule.output.shape) == tuple: 115 | submodule.output[0][:] = t.zeros_like(x[0]) 116 | else: 117 | submodule.output = t.zeros_like(x) 118 | else: 119 | raise ValueError(f"Invalid value for io: {io}") 120 | 121 | input = model.input.save() 122 | logits_zero = model.output.save() 123 | logits_zero = logits_zero.value 124 | 125 | # get everything into the right format 126 | try: 127 | logits_original = logits_original.logits 128 | logits_reconstructed = logits_reconstructed.logits 129 | logits_zero = logits_zero.logits 130 | except: 131 | pass 132 | 133 | if isinstance(text, t.Tensor): 134 | tokens = text 135 | else: 136 | try: 137 | tokens = input[1]['input_ids'] 138 | except: 139 | tokens = input[1]['input'] 140 | 141 | # compute losses 142 | losses = [] 143 | if hasattr(model, 'tokenizer') and model.tokenizer is not None: 144 | loss_kwargs = {'ignore_index': model.tokenizer.pad_token_id} 145 | else: 146 | loss_kwargs = {} 147 | for logits in [logits_original, logits_reconstructed, logits_zero]: 148 | loss = t.nn.CrossEntropyLoss(**loss_kwargs)( 149 | logits[:, 
:-1, :].reshape(-1, logits.shape[-1]), tokens[:, 1:].reshape(-1) 150 | ) 151 | losses.append(loss) 152 | 153 | return tuple(losses) 154 | 155 | 156 | def evaluate( 157 | dictionary, # a dictionary 158 | activations, # a generator of activations; if an ActivationBuffer, also compute loss recovered 159 | max_len=128, # max context length for loss recovered 160 | batch_size=8, # batch size for loss recovered 161 | num_batches=100, # number of batches to evaluate, 162 | io="out", # can be 'in', 'out', or 'in_and_out' 163 | normalize_batch=False, # normalize batch before passing through dictionary 164 | tracer_args={'use_cache': False, 'output_attentions': False}, # minimize cache during model trace. 165 | device="cpu", 166 | ): 167 | with t.no_grad(): 168 | metrics_lists = { 169 | "l2_loss": [], "l1_loss": [], "mse_loss": [], "l0": [], 170 | "frac_alive": [], "frac_variance_explained": [], "cossim": [], 171 | "l2_ratio": [], "frac_recovered": [] 172 | } 173 | 174 | for _ in tqdm(range(num_batches)): 175 | try: 176 | x = next(activations).to(device) 177 | if normalize_batch: 178 | x = x / x.norm(dim=-1).mean() * (dictionary.activation_dim ** 0.5) 179 | 180 | except StopIteration: 181 | raise StopIteration( 182 | "Not enough activations in buffer. Pass a buffer with a smaller batch size or more data." 
183 | ) 184 | 185 | x_hat, f = dictionary(x, output_features=True) 186 | l2_loss = t.linalg.norm(x - x_hat, dim=-1).mean() 187 | mse_loss = (x - x_hat).pow(2).sum(dim=-1).mean() 188 | l1_loss = f.norm(p=1, dim=-1).mean() 189 | l0 = (f != 0).float().sum(dim=-1).mean() 190 | frac_alive = t.flatten(f, start_dim=0, end_dim=1).any(dim=0).sum() / dictionary.dict_size 191 | 192 | # cosine similarity between x and x_hat 193 | x_normed = x / t.linalg.norm(x, dim=-1, keepdim=True) 194 | x_hat_normed = x_hat / t.linalg.norm(x_hat, dim=-1, keepdim=True) 195 | cossim = (x_normed * x_hat_normed).sum(dim=-1).mean() 196 | 197 | # l2 ratio 198 | l2_ratio = (t.linalg.norm(x_hat, dim=-1) / t.linalg.norm(x, dim=-1)).mean() 199 | 200 | #compute variance explained 201 | total_variance = t.var(x, dim=0).sum() 202 | residual_variance = t.var(x - x_hat, dim=0).sum() 203 | frac_variance_explained = (1 - residual_variance / total_variance) 204 | 205 | # compute loss recovered 206 | loss_original, loss_reconstructed, loss_zero = loss_recovered( 207 | activations.text_batch(batch_size=batch_size), 208 | activations.model, 209 | activations.submodule, 210 | dictionary, 211 | max_len=max_len, 212 | normalize_batch=normalize_batch, 213 | io=io, 214 | tracer_args=tracer_args 215 | ) 216 | frac_recovered = (loss_reconstructed - loss_zero) / (loss_original - loss_zero) 217 | 218 | metrics_lists["l2_loss"].append(l2_loss.item()) 219 | metrics_lists["l1_loss"].append(l1_loss.item()) 220 | metrics_lists["mse_loss"].append(mse_loss.item()) 221 | metrics_lists["l0"].append(l0.item()) 222 | metrics_lists["frac_alive"].append(frac_alive.item()) 223 | metrics_lists["frac_variance_explained"].append(frac_variance_explained.item()) 224 | metrics_lists["cossim"].append(cossim.item()) 225 | metrics_lists["l2_ratio"].append(l2_ratio.item()) 226 | metrics_lists["frac_recovered"].append(frac_recovered.item()) 227 | 228 | out = {k: sum(v)/len(v) for k, v in metrics_lists.items()} 229 | 230 | return out 
-------------------------------------------------------------------------------- /table/run_parallel.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # GPT-2 training commands 4 | python3 table/train_switch_table.py --device cuda:0 --layer 2 --lm openai-community/gpt2 --ks 64 --activation_dim 768 --dict_ratio 32 --num_experts 8 --type resid --sae_type switch --steps 20000 & 5 | python3 table/train_switch_table.py --device cuda:1 --layer 2 --lm openai-community/gpt2 --ks 64 --activation_dim 768 --dict_ratio 32 --num_experts 8 --type attn --sae_type switch --steps 20000 & 6 | python3 table/train_switch_table.py --device cuda:2 --layer 2 --lm openai-community/gpt2 --ks 64 --activation_dim 768 --dict_ratio 32 --num_experts 8 --type mlp --sae_type switch --steps 20000 & 7 | python3 table/train_switch_table.py --device cuda:3 --layer 2 --lm openai-community/gpt2 --ks 64 --activation_dim 768 --dict_ratio 32 --num_experts 8 --type resid --sae_type topk --steps 20000 & 8 | python3 table/train_switch_table.py --device cuda:4 --layer 2 --lm openai-community/gpt2 --ks 64 --activation_dim 768 --dict_ratio 32 --num_experts 8 --type attn --sae_type topk --steps 20000 & 9 | python3 table/train_switch_table.py --device cuda:5 --layer 2 --lm openai-community/gpt2 --ks 64 --activation_dim 768 --dict_ratio 32 --num_experts 8 --type mlp --sae_type topk --steps 20000 & 10 | python3 table/train_switch_table.py --device cuda:6 --layer 4 --lm openai-community/gpt2 --ks 64 --activation_dim 768 --dict_ratio 32 --num_experts 8 --type resid --sae_type switch --steps 20000 & 11 | python3 table/train_switch_table.py --device cuda:7 --layer 4 --lm openai-community/gpt2 --ks 64 --activation_dim 768 --dict_ratio 32 --num_experts 8 --type attn --sae_type switch --steps 20000 & 12 | 13 | wait 14 | 15 | python3 table/train_switch_table.py --device cuda:0 --layer 4 --lm openai-community/gpt2 --ks 64 --activation_dim 768 --dict_ratio 32 
--num_experts 8 --type mlp --sae_type switch --steps 20000 & 16 | python3 table/train_switch_table.py --device cuda:1 --layer 4 --lm openai-community/gpt2 --ks 64 --activation_dim 768 --dict_ratio 32 --num_experts 8 --type resid --sae_type topk --steps 20000 & 17 | python3 table/train_switch_table.py --device cuda:2 --layer 4 --lm openai-community/gpt2 --ks 64 --activation_dim 768 --dict_ratio 32 --num_experts 8 --type attn --sae_type topk --steps 20000 & 18 | python3 table/train_switch_table.py --device cuda:3 --layer 4 --lm openai-community/gpt2 --ks 64 --activation_dim 768 --dict_ratio 32 --num_experts 8 --type mlp --sae_type topk --steps 20000 & 19 | python3 table/train_switch_table.py --device cuda:4 --layer 8 --lm openai-community/gpt2 --ks 64 --activation_dim 768 --dict_ratio 32 --num_experts 8 --type resid --sae_type switch --steps 20000 & 20 | python3 table/train_switch_table.py --device cuda:5 --layer 8 --lm openai-community/gpt2 --ks 64 --activation_dim 768 --dict_ratio 32 --num_experts 8 --type attn --sae_type switch --steps 20000 & 21 | python3 table/train_switch_table.py --device cuda:6 --layer 8 --lm openai-community/gpt2 --ks 64 --activation_dim 768 --dict_ratio 32 --num_experts 8 --type mlp --sae_type switch --steps 20000 & 22 | python3 table/train_switch_table.py --device cuda:7 --layer 8 --lm openai-community/gpt2 --ks 64 --activation_dim 768 --dict_ratio 32 --num_experts 8 --type resid --sae_type topk --steps 20000 & 23 | 24 | wait 25 | 26 | python3 table/train_switch_table.py --device cuda:0 --layer 8 --lm openai-community/gpt2 --ks 64 --activation_dim 768 --dict_ratio 32 --num_experts 8 --type attn --sae_type topk --steps 20000 & 27 | python3 table/train_switch_table.py --device cuda:1 --layer 8 --lm openai-community/gpt2 --ks 64 --activation_dim 768 --dict_ratio 32 --num_experts 8 --type mlp --sae_type topk --steps 20000 & 28 | python3 table/train_switch_table.py --device cuda:2 --layer 10 --lm openai-community/gpt2 --ks 64 --activation_dim 
768 --dict_ratio 32 --num_experts 8 --type resid --sae_type switch --steps 20000 & 29 | python3 table/train_switch_table.py --device cuda:3 --layer 10 --lm openai-community/gpt2 --ks 64 --activation_dim 768 --dict_ratio 32 --num_experts 8 --type attn --sae_type switch --steps 20000 & 30 | python3 table/train_switch_table.py --device cuda:4 --layer 10 --lm openai-community/gpt2 --ks 64 --activation_dim 768 --dict_ratio 32 --num_experts 8 --type mlp --sae_type switch --steps 20000 & 31 | python3 table/train_switch_table.py --device cuda:5 --layer 10 --lm openai-community/gpt2 --ks 64 --activation_dim 768 --dict_ratio 32 --num_experts 8 --type resid --sae_type topk --steps 20000 & 32 | python3 table/train_switch_table.py --device cuda:6 --layer 10 --lm openai-community/gpt2 --ks 64 --activation_dim 768 --dict_ratio 32 --num_experts 8 --type attn --sae_type topk --steps 20000 & 33 | python3 table/train_switch_table.py --device cuda:7 --layer 10 --lm openai-community/gpt2 --ks 64 --activation_dim 768 --dict_ratio 32 --num_experts 8 --type mlp --sae_type topk --steps 20000 & 34 | 35 | wait 36 | 37 | python3 table/train_switch_table.py --device cuda:0 --layer 12 --lm openai-community/gpt2 --ks 64 --activation_dim 768 --dict_ratio 32 --num_experts 8 --type resid --sae_type switch --steps 20000 & 38 | python3 table/train_switch_table.py --device cuda:1 --layer 12 --lm openai-community/gpt2 --ks 64 --activation_dim 768 --dict_ratio 32 --num_experts 8 --type attn --sae_type switch --steps 20000 & 39 | python3 table/train_switch_table.py --device cuda:2 --layer 12 --lm openai-community/gpt2 --ks 64 --activation_dim 768 --dict_ratio 32 --num_experts 8 --type mlp --sae_type switch --steps 20000 & 40 | python3 table/train_switch_table.py --device cuda:3 --layer 12 --lm openai-community/gpt2 --ks 64 --activation_dim 768 --dict_ratio 32 --num_experts 8 --type resid --sae_type topk --steps 20000 & 41 | python3 table/train_switch_table.py --device cuda:4 --layer 12 --lm 
openai-community/gpt2 --ks 64 --activation_dim 768 --dict_ratio 32 --num_experts 8 --type attn --sae_type topk --steps 20000 & 42 | python3 table/train_switch_table.py --device cuda:5 --layer 12 --lm openai-community/gpt2 --ks 64 --activation_dim 768 --dict_ratio 32 --num_experts 8 --type mlp --sae_type topk --steps 20000 & 43 | 44 | wait 45 | 46 | # Gemma-2B training commands 47 | python3 table/train_switch_table.py --device cuda:0 --layer 12 --lm google/gemma-2b --ks 64 --activation_dim 2048 --dict_ratio 32 --num_experts 8 --type resid --sae_type switch --steps 2000000 & 48 | python3 table/train_switch_table.py --device cuda:1 --layer 12 --lm google/gemma-2b --ks 64 --activation_dim 2048 --dict_ratio 32 --num_experts 8 --type resid --sae_type topk --steps 2000000 & 49 | 50 | wait 51 | -------------------------------------------------------------------------------- /table/train_switch_table.py: -------------------------------------------------------------------------------- 1 | # %% 2 | import os 3 | import sys 4 | 5 | # Add the parent directory to the sys path 6 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 7 | 8 | 9 | from dictionary_learning.trainers.top_k import AutoEncoderTopK, TrainerTopK 10 | from nnsight import LanguageModel 11 | import torch as t 12 | from dictionary_learning import ActivationBuffer 13 | from dictionary_learning.training import trainSAE 14 | from dictionary_learning.utils import hf_dataset_to_generator, cfg_filename, str2bool 15 | from dictionary_learning.trainers.switch import SwitchAutoEncoder, SwitchTrainer 16 | from dictionary_learning.evaluation import evaluate 17 | import wandb 18 | import argparse 19 | import itertools 20 | from config import hf 21 | 22 | 23 | # %% 24 | parser = argparse.ArgumentParser() 25 | parser.add_argument("--device", type=str, default="cuda:1", help="Device to run on") 26 | parser.add_argument("--lm", type=str, help="Language model to use") 27 | parser.add_argument( 28 | 
"--layer", type=int, default=12, help="Layer to extract activations from" 29 | ) 30 | parser.add_argument( 31 | "--type", type=str, default="resid", choices=["resid", "mlp", "attn"] 32 | ) 33 | parser.add_argument( 34 | "--sae_type", type=str, default="switch", choices=["switch", "topk"] 35 | ) 36 | parser.add_argument("--ks", nargs="+", type=int, default=[64], help="List of k values") 37 | parser.add_argument( 38 | "--activation_dim", type=int, default=2048, help="Dimension of activations" 39 | ) 40 | parser.add_argument("--dict_ratio", type=int, default=32, help="Dictionary size ratio") 41 | parser.add_argument( 42 | "--num_experts", nargs="+", type=int, default=[8], help="List of number of experts" 43 | ) 44 | parser.add_argument( 45 | "--steps", type=int, default=10000, help="Number of steps to train for" 46 | ) 47 | 48 | args = parser.parse_args() 49 | 50 | device = args.device 51 | layer = args.layer 52 | lm = args.lm 53 | print(lm) 54 | ks = args.ks 55 | activation_dim = args.activation_dim 56 | dict_ratio = args.dict_ratio 57 | num_experts = args.num_experts 58 | lb_alphas = [3] 59 | heavisides = [False] 60 | n_ctxs = 3e4 61 | batch_size = 8192 62 | steps = args.steps 63 | 64 | # %% 65 | 66 | model = LanguageModel(lm, dispatch=True, device_map=device) 67 | 68 | # %% 69 | 70 | if lm == "openai-community/gpt2": 71 | 72 | if args.type == "resid": 73 | submodule = model.transformer.h[layer] 74 | elif args.type == "mlp": 75 | submodule = model.transformer.h[layer].mlp 76 | elif args.type == "attn": 77 | submodule = model.transformer.h[layer].attn 78 | 79 | else: 80 | 81 | if args.type == "resid": 82 | submodule = model.model.layers[layer] 83 | elif args.type == "mlp": 84 | submodule = model.model.layers[layer].mlp 85 | elif args.type == "attn": 86 | submodule = model.model.layers[layer].self_attn 87 | 88 | # %% 89 | 90 | 91 | data = hf_dataset_to_generator(hf) 92 | buffer = ActivationBuffer( 93 | data, 94 | model, 95 | submodule, 96 | d_submodule=activation_dim, 
97 | n_ctxs=n_ctxs, 98 | device="cpu", 99 | out_batch_size=batch_size, 100 | refresh_batch_size=512 if lm == "openai-community/gpt2" else 64, 101 | ) 102 | 103 | if args.sae_type == "switch": 104 | 105 | base_trainer_config = { 106 | "trainer": SwitchTrainer, 107 | "dict_class": SwitchAutoEncoder, 108 | "activation_dim": activation_dim, 109 | "dict_size": dict_ratio * activation_dim, 110 | "auxk_alpha": 1 / 32, 111 | "decay_start": int(steps * 0.8), 112 | "steps": steps, 113 | "seed": 0, 114 | "device": device, 115 | "layer": layer, 116 | "lm_name": lm, 117 | "wandb_name": "SwitchAutoEncoder", 118 | } 119 | 120 | trainer_configs = [ 121 | ( 122 | base_trainer_config 123 | | { 124 | "k": combo[0], 125 | "experts": combo[1], 126 | "heaviside": combo[2], 127 | "lb_alpha": combo[3], 128 | } 129 | ) 130 | for combo in itertools.product(ks, num_experts, heavisides, lb_alphas) 131 | ] 132 | else: 133 | base_trainer_config = { 134 | "trainer": TrainerTopK, 135 | "dict_class": AutoEncoderTopK, 136 | "activation_dim": activation_dim, 137 | "dict_size": args.dict_ratio * activation_dim, 138 | "auxk_alpha": 1 / 32, 139 | "decay_start": int(steps * 0.8), 140 | "steps": steps, 141 | "seed": 0, 142 | "device": device, 143 | "layer": layer, 144 | "lm_name": lm, 145 | "wandb_name": "AutoEncoderTopK", 146 | } 147 | 148 | trainer_configs = [(base_trainer_config | {"k": k}) for k in args.ks] 149 | 150 | 151 | wandb.init( 152 | entity="josh_engels", 153 | project="Switch", 154 | config={ 155 | f'{trainer_config["wandb_name"]}-{i}': trainer_config 156 | for i, trainer_config in enumerate(trainer_configs) 157 | }, 158 | ) 159 | 160 | trainSAE( 161 | buffer, 162 | trainer_configs=trainer_configs, 163 | save_dir="dictionaries", 164 | log_steps=1, 165 | steps=steps, 166 | save_steps=1000, 167 | ) 168 | 169 | print("Training finished. 
Evaluating SAE...", flush=True) 170 | for i, trainer_config in enumerate(trainer_configs): 171 | if args.sae_type == "switch": 172 | ae = SwitchAutoEncoder.from_pretrained( 173 | f"dictionaries/{cfg_filename(trainer_config)}/ae.pt", 174 | k=trainer_config["k"], 175 | experts=trainer_config["experts"], 176 | heaviside=trainer_config["heaviside"], 177 | device=device, 178 | ) 179 | metrics = evaluate(ae, buffer, device=device) 180 | log = {} 181 | log.update( 182 | {f'{trainer_config["wandb_name"]}-{i}/{k}': v for k, v in metrics.items()} 183 | ) 184 | wandb.log(log, step=steps + 1) 185 | else: 186 | ae = AutoEncoderTopK.from_pretrained( 187 | f"dictionaries/{cfg_filename(trainer_config)}/ae.pt", 188 | k=trainer_config["k"], 189 | device=device, 190 | ) 191 | metrics = evaluate(ae, buffer, device=device) 192 | log = {} 193 | log.update( 194 | {f'{trainer_config["wandb_name"]}-{i}/{k}': v for k, v in metrics.items()} 195 | ) 196 | wandb.log(log, step=steps + 1) 197 | wandb.finish() 198 | -------------------------------------------------------------------------------- /train-gated.py: -------------------------------------------------------------------------------- 1 | from nnsight import LanguageModel 2 | import torch as t 3 | from dictionary_learning import ActivationBuffer 4 | from dictionary_learning.training import trainSAE 5 | from dictionary_learning.utils import hf_dataset_to_generator, cfg_filename 6 | from dictionary_learning.dictionary import GatedAutoEncoder 7 | from dictionary_learning.trainers.gdm import GatedSAETrainer 8 | from dictionary_learning.evaluation import evaluate 9 | import wandb 10 | import argparse 11 | from config import lm, activation_dim, layer, hf, steps, n_ctxs 12 | 13 | parser = argparse.ArgumentParser() 14 | parser.add_argument("--gpu", required=True) 15 | parser.add_argument('--lr', type=float, default=1e-3) ## 3e-4, 5e-5 16 | parser.add_argument('--dict_ratio', type=int, default=32) 17 | parser.add_argument("--l1_penalties", 
nargs="+", type=float, required=True) 18 | args = parser.parse_args() 19 | 20 | device = f'cuda:{args.gpu}' 21 | model = LanguageModel(lm, dispatch=True, device_map=device) 22 | submodule = model.transformer.h[layer] 23 | data = hf_dataset_to_generator(hf) 24 | buffer = ActivationBuffer(data, model, submodule, d_submodule=activation_dim, n_ctxs=n_ctxs, device=device) 25 | 26 | base_trainer_config = { 27 | 'trainer' : GatedSAETrainer, 28 | 'dict_class' : GatedAutoEncoder, 29 | 'activation_dim' : activation_dim, 30 | 'dict_size' : args.dict_ratio * activation_dim, 31 | 'lr' : args.lr, 32 | 'warmup_steps' : 1000, 33 | 'resample_steps' : None, 34 | 'seed' : 0, 35 | 'device' : device, 36 | 'layer' : layer, 37 | 'lm_name' : lm, 38 | 'wandb_name' : 'GatedSAETrainer' 39 | } 40 | 41 | trainer_configs = [(base_trainer_config | {'l1_penalty': l1_penalty}) for l1_penalty in args.l1_penalties] 42 | 43 | wandb.init(entity="amudide", project="Gated-BigLR", config={f'{trainer_config["wandb_name"]}-{i}' : trainer_config for i, trainer_config in enumerate(trainer_configs)}) 44 | 45 | trainSAE(buffer, trainer_configs=trainer_configs, save_dir='dictionaries', log_steps=1, steps=steps) 46 | 47 | print("Training finished. 
Evaluating SAE...", flush=True) 48 | for i, trainer_config in enumerate(trainer_configs): 49 | ae = GatedAutoEncoder.from_pretrained(f'dictionaries/{cfg_filename(trainer_config)}/ae.pt', device=device) 50 | metrics = evaluate(ae, buffer, device=device) 51 | log = {} 52 | log.update({f'{trainer_config["wandb_name"]}-{i}/{k}' : v for k, v in metrics.items()}) 53 | wandb.log(log, step=steps+1) 54 | wandb.finish() -------------------------------------------------------------------------------- /train-jump.py: -------------------------------------------------------------------------------- 1 | from nnsight import LanguageModel 2 | import torch as t 3 | from dictionary_learning import ActivationBuffer 4 | from dictionary_learning.training import trainSAE 5 | from dictionary_learning.utils import hf_dataset_to_generator, cfg_filename 6 | from dictionary_learning.dictionary import JumpReluAutoEncoder 7 | from dictionary_learning.trainers.jump_relu import JumpReluTrainer 8 | from dictionary_learning.evaluation import evaluate 9 | import wandb 10 | import argparse 11 | from config import lm, activation_dim, layer, hf, steps, n_ctxs 12 | 13 | parser = argparse.ArgumentParser() 14 | parser.add_argument("--gpu", required=True) 15 | parser.add_argument('--lr', type=float, default=7e-5) 16 | parser.add_argument('--dict_ratio', type=int, default=32) 17 | parser.add_argument("--l0_penalties", nargs="+", type=float, required=True) 18 | args = parser.parse_args() 19 | 20 | device = f'cuda:{args.gpu}' 21 | model = LanguageModel(lm, dispatch=True, device_map=device) 22 | submodule = model.transformer.h[layer] 23 | data = hf_dataset_to_generator(hf) 24 | buffer = ActivationBuffer(data, model, submodule, d_submodule=activation_dim, n_ctxs=n_ctxs, device=device) 25 | 26 | base_trainer_config = { 27 | 'trainer' : JumpReluTrainer, 28 | 'dict_class' : JumpReluAutoEncoder, 29 | 'activation_dim' : activation_dim, 30 | 'dict_size' : args.dict_ratio * activation_dim, 31 | 'lr' : args.lr, 32 | 
'warmup_steps' : 1000, 33 | 'seed' : 0, 34 | 'device' : device, 35 | 'layer' : layer, 36 | 'lm_name' : lm, 37 | 'wandb_name' : 'JumpReluTrainer' 38 | } 39 | 40 | trainer_configs = [(base_trainer_config | {'l0_penalty': l0_penalty}) for l0_penalty in args.l0_penalties] 41 | 42 | wandb.init(entity="amudide", project="Jump", config={f'{trainer_config["wandb_name"]}-{i}' : trainer_config for i, trainer_config in enumerate(trainer_configs)}) 43 | 44 | trainSAE(buffer, trainer_configs=trainer_configs, save_dir='dictionaries', log_steps=1, steps=steps) 45 | 46 | print("Training finished. Evaluating SAE...", flush=True) 47 | for i, trainer_config in enumerate(trainer_configs): 48 | ae = JumpReluAutoEncoder.from_pretrained(f'dictionaries/{cfg_filename(trainer_config)}/ae.pt', device=device) 49 | metrics = evaluate(ae, buffer, device=device) 50 | log = {} 51 | log.update({f'{trainer_config["wandb_name"]}-{i}/{k}' : v for k, v in metrics.items()}) 52 | wandb.log(log, step=steps+1) 53 | wandb.finish() -------------------------------------------------------------------------------- /train-moe.py: -------------------------------------------------------------------------------- 1 | from nnsight import LanguageModel 2 | import torch as t 3 | from dictionary_learning import ActivationBuffer 4 | from dictionary_learning.training import trainSAE 5 | from dictionary_learning.utils import hf_dataset_to_generator, cfg_filename, str2bool 6 | from dictionary_learning.trainers.moe import MoEAutoEncoder, MoETrainer 7 | from dictionary_learning.evaluation import evaluate 8 | import wandb 9 | import argparse 10 | import itertools 11 | from config import lm, activation_dim, layer, hf, steps, n_ctxs 12 | 13 | parser = argparse.ArgumentParser() 14 | parser.add_argument("--gpu", required=True) 15 | parser.add_argument('--dict_ratio', type=int, default=32) 16 | parser.add_argument("--ks", nargs="+", type=int, required=True) 17 | parser.add_argument("--num_experts", nargs="+", type=int, 
required=True) 18 | parser.add_argument("--es", nargs="+", type=int, required=True) 19 | parser.add_argument("--heavisides", nargs="+", type=str2bool, required=True) 20 | args = parser.parse_args() 21 | 22 | device = f'cuda:{args.gpu}' 23 | model = LanguageModel(lm, dispatch=True, device_map=device) 24 | submodule = model.transformer.h[layer] 25 | data = hf_dataset_to_generator(hf) 26 | buffer = ActivationBuffer(data, model, submodule, d_submodule=activation_dim, n_ctxs=n_ctxs, device=device) 27 | 28 | base_trainer_config = { 29 | 'trainer' : MoETrainer, 30 | 'dict_class' : MoEAutoEncoder, 31 | 'activation_dim' : activation_dim, 32 | 'dict_size' : args.dict_ratio * activation_dim, 33 | 'auxk_alpha' : 1/32, 34 | 'decay_start' : int(steps * 0.8), 35 | 'steps' : steps, 36 | 'seed' : 0, 37 | 'device' : device, 38 | 'layer' : layer, 39 | 'lm_name' : lm, 40 | 'wandb_name' : 'MoEAutoEncoder' 41 | } 42 | 43 | trainer_configs = [(base_trainer_config | {'k': combo[0], 'experts': combo[1], 'e': combo[2], 'heaviside': combo[3]}) for combo in itertools.product(args.ks, args.num_experts, args.es, args.heavisides)] 44 | 45 | wandb.init(entity="amudide", project="MoE", config={f'{trainer_config["wandb_name"]}-{i}' : trainer_config for i, trainer_config in enumerate(trainer_configs)}) 46 | 47 | trainSAE(buffer, trainer_configs=trainer_configs, save_dir='dictionaries', log_steps=1, steps=steps) 48 | 49 | print("Training finished. 
Evaluating SAE...", flush=True) 50 | for i, trainer_config in enumerate(trainer_configs): 51 | ae = MoEAutoEncoder.from_pretrained(f'dictionaries/{cfg_filename(trainer_config)}/ae.pt', 52 | k = trainer_config['k'], experts = trainer_config['experts'], e = trainer_config['e'], heaviside = trainer_config['heaviside'], device=device) 53 | metrics = evaluate(ae, buffer, device=device) 54 | log = {} 55 | log.update({f'{trainer_config["wandb_name"]}-{i}/{k}' : v for k, v in metrics.items()}) 56 | wandb.log(log, step=steps+1) 57 | wandb.finish() -------------------------------------------------------------------------------- /train-relu.py: -------------------------------------------------------------------------------- 1 | from nnsight import LanguageModel 2 | import torch as t 3 | from dictionary_learning import ActivationBuffer 4 | from dictionary_learning.training import trainSAE 5 | from dictionary_learning.utils import hf_dataset_to_generator, cfg_filename 6 | from dictionary_learning.dictionary import AutoEncoderNew 7 | from dictionary_learning.trainers.standard_new import StandardTrainerNew 8 | from dictionary_learning.evaluation import evaluate 9 | import wandb 10 | import argparse 11 | from config import lm, activation_dim, layer, hf, steps, n_ctxs 12 | 13 | parser = argparse.ArgumentParser() 14 | parser.add_argument("--gpu", required=True) 15 | parser.add_argument('--lr', type=float, default=5e-5) ## 3e-4 16 | parser.add_argument('--dict_ratio', type=int, default=32) 17 | parser.add_argument("--l1_penalties", nargs="+", type=float, required=True) 18 | args = parser.parse_args() 19 | 20 | device = f'cuda:{args.gpu}' 21 | model = LanguageModel(lm, dispatch=True, device_map=device) 22 | submodule = model.transformer.h[layer] 23 | data = hf_dataset_to_generator(hf) 24 | buffer = ActivationBuffer(data, model, submodule, d_submodule=activation_dim, n_ctxs=n_ctxs, device=device) 25 | 26 | base_trainer_config = { 27 | 'trainer' : StandardTrainerNew, 28 | 'dict_class' : 
AutoEncoderNew, 29 | 'activation_dim' : activation_dim, 30 | 'dict_size' : args.dict_ratio * activation_dim, 31 | 'lr' : args.lr, 32 | 'lambda_warm_steps' : int(steps * 0.05), 33 | 'decay_start' : int(steps * 0.8), 34 | 'steps' : steps, 35 | 'seed' : 0, 36 | 'device' : device, 37 | 'wandb_name' : 'StandardTrainerNew_Anthropic' 38 | } 39 | 40 | trainer_configs = [(base_trainer_config | {'l1_penalty': l1_penalty}) for l1_penalty in args.l1_penalties] 41 | 42 | wandb.init(entity="amudide", project="ReLU", config={f'{trainer_config["wandb_name"]}-{i}' : trainer_config for i, trainer_config in enumerate(trainer_configs)}) 43 | 44 | trainSAE(buffer, trainer_configs=trainer_configs, save_dir='dictionaries', log_steps=1, steps=steps) 45 | 46 | print("Training finished. Evaluating SAE...", flush=True) 47 | for i, trainer_config in enumerate(trainer_configs): 48 | ae = AutoEncoderNew.from_pretrained(f'dictionaries/{cfg_filename(trainer_config)}/ae.pt', device=device) 49 | metrics = evaluate(ae, buffer, device=device) 50 | log = {} 51 | log.update({f'{trainer_config["wandb_name"]}-{i}/{k}' : v for k, v in metrics.items()}) 52 | wandb.log(log, step=steps+1) 53 | wandb.finish() -------------------------------------------------------------------------------- /train-switch-1on.py: -------------------------------------------------------------------------------- 1 | from nnsight import LanguageModel 2 | import torch as t 3 | from dictionary_learning import ActivationBuffer 4 | from dictionary_learning.training import trainSAE 5 | from dictionary_learning.utils import hf_dataset_to_generator, cfg_filename, str2bool 6 | from dictionary_learning.trainers.switch1on import SwitchAutoEncoder, SwitchTrainer 7 | from dictionary_learning.evaluation import evaluate 8 | import wandb 9 | import argparse 10 | import itertools 11 | from config import lm, activation_dim, layer, hf, steps, n_ctxs 12 | 13 | parser = argparse.ArgumentParser() 14 | parser.add_argument("--gpu", required=True) 15 | 
parser.add_argument('--dict_ratio', type=int, default=32) 16 | parser.add_argument("--ks", nargs="+", type=int, required=True) 17 | parser.add_argument("--num_experts", nargs="+", type=int, required=True) 18 | parser.add_argument("--lb_alphas", nargs="+", type=float, default=[3.0]) 19 | parser.add_argument("--heavisides", nargs="+", type=str2bool, default=[False]) 20 | args = parser.parse_args() 21 | 22 | device = f'cuda:{args.gpu}' 23 | model = LanguageModel(lm, dispatch=True, device_map=device) 24 | submodule = model.transformer.h[layer] 25 | data = hf_dataset_to_generator(hf) 26 | buffer = ActivationBuffer(data, model, submodule, d_submodule=activation_dim, n_ctxs=n_ctxs, device=device) 27 | 28 | base_trainer_config = { 29 | 'trainer' : SwitchTrainer, 30 | 'dict_class' : SwitchAutoEncoder, 31 | 'activation_dim' : activation_dim, 32 | 'dict_size' : args.dict_ratio * activation_dim, 33 | 'auxk_alpha' : 1/32, 34 | 'decay_start' : int(steps * 0.8), 35 | 'steps' : steps, 36 | 'seed' : 0, 37 | 'device' : device, 38 | 'layer' : layer, 39 | 'lm_name' : lm, 40 | 'wandb_name' : 'SwitchAutoEncoder' 41 | } 42 | 43 | trainer_configs = [(base_trainer_config | {'k': combo[0], 'experts': combo[1], 'heaviside': combo[2], 'lb_alpha': combo[3]}) for combo in itertools.product(args.ks, args.num_experts, args.heavisides, args.lb_alphas)] 44 | 45 | wandb.init(entity="amudide", project="Switch (1 Always On)", config={f'{trainer_config["wandb_name"]}-{i}' : trainer_config for i, trainer_config in enumerate(trainer_configs)}) 46 | 47 | trainSAE(buffer, trainer_configs=trainer_configs, save_dir='dictionaries', log_steps=1, steps=steps) 48 | 49 | print("Training finished. 
Evaluating SAE...", flush=True) 50 | for i, trainer_config in enumerate(trainer_configs): 51 | ae = SwitchAutoEncoder.from_pretrained(f'dictionaries/{cfg_filename(trainer_config)}/ae.pt', 52 | k = trainer_config['k'], experts = trainer_config['experts'], heaviside = trainer_config['heaviside'], device=device) 53 | metrics = evaluate(ae, buffer, device=device) 54 | log = {} 55 | log.update({f'{trainer_config["wandb_name"]}-{i}/{k}' : v for k, v in metrics.items()}) 56 | wandb.log(log, step=steps+1) 57 | wandb.finish() -------------------------------------------------------------------------------- /train-switch-flop.py: -------------------------------------------------------------------------------- 1 | from nnsight import LanguageModel 2 | import torch as t 3 | from dictionary_learning import ActivationBuffer 4 | from dictionary_learning.training import trainSAE 5 | from dictionary_learning.utils import hf_dataset_to_generator, cfg_filename, str2bool 6 | from dictionary_learning.trainers.switch import SwitchAutoEncoder, SwitchTrainer 7 | from dictionary_learning.evaluation import evaluate 8 | import wandb 9 | import argparse 10 | import itertools 11 | from config import lm, activation_dim, layer, hf, steps, n_ctxs 12 | 13 | parser = argparse.ArgumentParser() 14 | parser.add_argument("--gpu", required=True) 15 | parser.add_argument('--dict_ratio', type=int, default=32) 16 | parser.add_argument("--ks", nargs="+", type=int, required=True) 17 | parser.add_argument("--num_experts", nargs="+", type=int, required=True) 18 | parser.add_argument("--lb_alphas", nargs="+", type=float, default=[3.0]) 19 | parser.add_argument("--heavisides", nargs="+", type=str2bool, default=[False]) 20 | args = parser.parse_args() 21 | 22 | device = f'cuda:{args.gpu}' 23 | model = LanguageModel(lm, dispatch=True, device_map=device) 24 | submodule = model.transformer.h[layer] 25 | data = hf_dataset_to_generator(hf) 26 | buffer = ActivationBuffer(data, model, submodule, 
d_submodule=activation_dim, out_batch_size=2048, n_ctxs=n_ctxs//2, device=device) ## 1/2 context, 1/4 batch size to save memory 27 | 28 | base_trainer_config = { 29 | 'trainer' : SwitchTrainer, 30 | 'dict_class' : SwitchAutoEncoder, 31 | 'activation_dim' : activation_dim, 32 | 'auxk_alpha' : 1/32, 33 | 'decay_start' : int(steps * 0.8), 34 | 'steps' : steps, 35 | 'seed' : 0, 36 | 'device' : device, 37 | 'layer' : layer, 38 | 'lm_name' : lm, 39 | 'wandb_name' : 'SwitchAutoEncoder' 40 | } 41 | 42 | trainer_configs = [(base_trainer_config | {'k': combo[0], 'experts': combo[1], 'dict_size' : args.dict_ratio * activation_dim * combo[1], 'heaviside': combo[2], 'lb_alpha': combo[3]}) for combo in itertools.product(args.ks, args.num_experts, args.heavisides, args.lb_alphas)] 43 | 44 | wandb.init(entity="amudide", project="Switch (FLOP Matched, LB)", config={f'{trainer_config["wandb_name"]}-{i}' : trainer_config for i, trainer_config in enumerate(trainer_configs)}) 45 | 46 | trainSAE(buffer, trainer_configs=trainer_configs, save_dir='dictionaries', log_steps=1, steps=steps) 47 | 48 | print("Training finished. 
Evaluating SAE...", flush=True) 49 | for i, trainer_config in enumerate(trainer_configs): 50 | ae = SwitchAutoEncoder.from_pretrained(f'dictionaries/{cfg_filename(trainer_config)}/ae.pt', 51 | k = trainer_config['k'], experts = trainer_config['experts'], heaviside = trainer_config['heaviside'], device=device) 52 | metrics = evaluate(ae, buffer, device=device) 53 | log = {} 54 | log.update({f'{trainer_config["wandb_name"]}-{i}/{k}' : v for k, v in metrics.items()}) 55 | wandb.log(log, step=steps+1) 56 | wandb.finish() -------------------------------------------------------------------------------- /train-switch.py: -------------------------------------------------------------------------------- 1 | from nnsight import LanguageModel 2 | import torch as t 3 | from dictionary_learning import ActivationBuffer 4 | from dictionary_learning.training import trainSAE 5 | from dictionary_learning.utils import hf_dataset_to_generator, cfg_filename, str2bool 6 | from dictionary_learning.trainers.switch import SwitchAutoEncoder, SwitchTrainer 7 | from dictionary_learning.evaluation import evaluate 8 | import wandb 9 | import argparse 10 | import itertools 11 | from config import lm, activation_dim, layer, hf, steps, n_ctxs 12 | 13 | parser = argparse.ArgumentParser() 14 | parser.add_argument("--gpu", required=True) 15 | parser.add_argument('--dict_ratio', type=int, default=32) 16 | parser.add_argument("--ks", nargs="+", type=int, required=True) 17 | parser.add_argument("--num_experts", nargs="+", type=int, required=True) 18 | parser.add_argument("--lb_alphas", nargs="+", type=float, default=[3.0]) 19 | parser.add_argument("--heavisides", nargs="+", type=str2bool, default=[False]) 20 | args = parser.parse_args() 21 | 22 | device = f'cuda:{args.gpu}' 23 | model = LanguageModel(lm, dispatch=True, device_map=device) 24 | submodule = model.transformer.h[layer] 25 | data = hf_dataset_to_generator(hf) 26 | buffer = ActivationBuffer(data, model, submodule, d_submodule=activation_dim, 
n_ctxs=n_ctxs, device=device) 27 | 28 | base_trainer_config = { 29 | 'trainer' : SwitchTrainer, 30 | 'dict_class' : SwitchAutoEncoder, 31 | 'activation_dim' : activation_dim, 32 | 'dict_size' : args.dict_ratio * activation_dim, 33 | 'auxk_alpha' : 1/32, 34 | 'decay_start' : int(steps * 0.8), 35 | 'steps' : steps, 36 | 'seed' : 0, 37 | 'device' : device, 38 | 'layer' : layer, 39 | 'lm_name' : lm, 40 | 'wandb_name' : 'SwitchAutoEncoder' 41 | } 42 | 43 | trainer_configs = [(base_trainer_config | {'k': combo[0], 'experts': combo[1], 'heaviside': combo[2], 'lb_alpha': combo[3]}) for combo in itertools.product(args.ks, args.num_experts, args.heavisides, args.lb_alphas)] 44 | 45 | wandb.init(entity="amudide", project="Switch (LB)", config={f'{trainer_config["wandb_name"]}-{i}' : trainer_config for i, trainer_config in enumerate(trainer_configs)}) 46 | 47 | trainSAE(buffer, trainer_configs=trainer_configs, save_dir='dictionaries', log_steps=1, steps=steps) 48 | 49 | print("Training finished. Evaluating SAE...", flush=True) 50 | for i, trainer_config in enumerate(trainer_configs): 51 | ae = SwitchAutoEncoder.from_pretrained(f'dictionaries/{cfg_filename(trainer_config)}/ae.pt', 52 | k = trainer_config['k'], experts = trainer_config['experts'], heaviside = trainer_config['heaviside'], device=device) 53 | metrics = evaluate(ae, buffer, device=device) 54 | log = {} 55 | log.update({f'{trainer_config["wandb_name"]}-{i}/{k}' : v for k, v in metrics.items()}) 56 | wandb.log(log, step=steps+1) 57 | wandb.finish() -------------------------------------------------------------------------------- /train-topk.py: -------------------------------------------------------------------------------- 1 | from nnsight import LanguageModel 2 | import torch as t 3 | from dictionary_learning import ActivationBuffer 4 | from dictionary_learning.training import trainSAE 5 | from dictionary_learning.utils import hf_dataset_to_generator, cfg_filename 6 | from dictionary_learning.trainers.top_k import 
AutoEncoderTopK, TrainerTopK 7 | from dictionary_learning.evaluation import evaluate 8 | import wandb 9 | import argparse 10 | from config import lm, activation_dim, layer, hf, steps, n_ctxs 11 | 12 | parser = argparse.ArgumentParser() 13 | parser.add_argument("--gpu", required=True) 14 | parser.add_argument('--dict_ratio', type=int, default=32) 15 | parser.add_argument("--ks", nargs="+", type=int, required=True) 16 | args = parser.parse_args() 17 | 18 | device = f'cuda:{args.gpu}' 19 | model = LanguageModel(lm, dispatch=True, device_map=device) 20 | submodule = model.transformer.h[layer] 21 | data = hf_dataset_to_generator(hf) 22 | buffer = ActivationBuffer(data, model, submodule, d_submodule=activation_dim, n_ctxs=n_ctxs, device=device) 23 | 24 | base_trainer_config = { 25 | 'trainer' : TrainerTopK, 26 | 'dict_class' : AutoEncoderTopK, 27 | 'activation_dim' : activation_dim, 28 | 'dict_size' : args.dict_ratio * activation_dim, 29 | 'auxk_alpha' : 1/32, 30 | 'decay_start' : int(steps * 0.8), 31 | 'steps' : steps, 32 | 'seed' : 0, 33 | 'device' : device, 34 | 'layer' : layer, 35 | 'lm_name' : lm, 36 | 'wandb_name' : 'AutoEncoderTopK' 37 | } 38 | 39 | trainer_configs = [(base_trainer_config | {'k': k}) for k in args.ks] 40 | 41 | wandb.init(entity="amudide", project="TopK (Frequent Log)", config={f'{trainer_config["wandb_name"]}-{i}' : trainer_config for i, trainer_config in enumerate(trainer_configs)}) 42 | 43 | trainSAE(buffer, trainer_configs=trainer_configs, save_dir='dictionaries', log_steps=1, steps=steps) 44 | 45 | print("Training finished. 
Evaluating SAE...", flush=True) 46 | for i, trainer_config in enumerate(trainer_configs): 47 | ae = AutoEncoderTopK.from_pretrained(f'dictionaries/{cfg_filename(trainer_config)}/ae.pt', k = trainer_config['k'], device=device) 48 | metrics = evaluate(ae, buffer, device=device) 49 | log = {} 50 | log.update({f'{trainer_config["wandb_name"]}-{i}/{k}' : v for k, v in metrics.items()}) 51 | wandb.log(log, step=steps+1) 52 | wandb.finish() --------------------------------------------------------------------------------