├── .gitignore
├── .gitmodules
├── .vscode
└── settings.json
├── README.md
├── auto_interp
├── compare_dictionaries.py
├── generate_scripts.py
├── results.pkl
└── sae_feature_evals.py
├── compare_geometry
├── examine_duplicate_features.py
├── find_duplicate_features.py
└── paper_plots.py
├── config.py
├── dictionary_learning
├── .DS_Store
├── .gitignore
├── __init__.py
├── buffer.py
├── config.py
├── dictionary.py
├── evaluation.py
├── grad_pursuit.py
├── interp.py
├── kernels.py
├── pretrained_dictionary_downloader.sh
├── requirements.txt
├── trainers
│ ├── gdm.py
│ ├── jump_relu.py
│ ├── moe.py
│ ├── standard.py
│ ├── standard_new.py
│ ├── switch.py
│ ├── switch1on.py
│ ├── top_k.py
│ └── trainer.py
├── training.py
└── utils.py
├── other
└── lobes.py
├── results
├── 1on
│ ├── 1on.ipynb
│ ├── l0_deltace.png
│ ├── l0_mse.png
│ ├── l0_recovered.png
│ ├── primary-switch-fast.csv
│ └── switch-heavy.csv
├── 1on_lb
│ ├── 1on_lb.csv
│ ├── alpha_deltace.png
│ ├── alpha_lossrec.png
│ ├── alpha_mse.png
│ └── lb.ipynb
├── efficiency
│ ├── efficiency-switch.csv
│ ├── efficiency-topk.csv
│ ├── efficiency.ipynb
│ └── efficiency.png
├── gated_lr
│ ├── gated_l0_lossrec.png
│ ├── gated_l0_mse.png
│ ├── gated_mse_deltace.png
│ ├── primary-gated-1e-3.csv
│ ├── primary-gated-3e-4.csv
│ ├── primary-gated-5e-5.csv
│ └── primary-gated.ipynb
├── heaviside_softmax
│ ├── experts.ipynb
│ ├── experts16h.csv
│ ├── experts16s.csv
│ ├── experts_l0_deltace.png
│ ├── experts_l0_lossrec.png
│ ├── experts_l0_mse.png
│ └── primary-experts.csv
├── load_balance
│ ├── alpha_deltace.png
│ ├── alpha_lossrec.png
│ ├── alpha_mse.pdf
│ ├── alpha_mse.png
│ ├── lb-sweep.csv
│ └── lb.ipynb
└── primary
│ ├── big-plot.ipynb
│ ├── flopmatch_l0_deltace.png
│ ├── flopmatch_l0_lossrec.png
│ ├── flopmatch_l0_mse.png
│ ├── flopmatch_mse_deltace.png
│ ├── l0_deltace.png
│ ├── l0_lossrec.png
│ ├── l0_mse.png
│ ├── mse_deltace.png
│ ├── primary-flop-match.ipynb
│ ├── primary-gated-clean.csv
│ ├── primary-gated.csv
│ ├── primary-relu-clean.csv
│ ├── primary-relu.csv
│ ├── primary-switch-fast.csv
│ ├── primary-switch-flop.csv
│ ├── primary-topk-clean.csv
│ ├── primary-topk.csv
│ ├── primary.ipynb
│ ├── switch_sae_l0_lr.pdf
│ ├── switch_sae_l0_mse.pdf
│ ├── switch_sae_pareto.pdf
│ ├── switch_sae_pareto_flop.pdf
│ └── switch_sae_pareto_width.pdf
├── save_activations.py
├── scaling_laws
└── attempt0
│ ├── plot.ipynb
│ ├── train-dense-topk.py
│ ├── train-dense-topk.sh
│ ├── train-switch-topk.py
│ └── train-switch-topk.sh
├── speed.ipynb
├── table
├── create_table_script.py
├── eval_table.py
├── modified_eval.py
├── run_parallel.sh
└── train_switch_table.py
├── train-gated.py
├── train-jump.py
├── train-moe.py
├── train-relu.py
├── train-switch-1on.py
├── train-switch-flop.py
├── train-switch.py
└── train-topk.py
/.gitignore:
--------------------------------------------------------------------------------
1 | dictionaries*
2 | wandb
3 | .DS_Store
4 | weights*
5 |
6 | # Byte-compiled / optimized / DLL files
7 | __pycache__/
8 | *.py[cod]
9 | *$py.class
10 |
11 | # C extensions
12 | *.so
13 |
14 | # Distribution / packaging
15 | .Python
16 | build/
17 | develop-eggs/
18 | dist/
19 | downloads/
20 | eggs/
21 | .eggs/
22 | lib/
23 | lib64/
24 | parts/
25 | sdist/
26 | var/
27 | wheels/
28 | share/python-wheels/
29 | *.egg-info/
30 | .installed.cfg
31 | *.egg
32 | MANIFEST
33 |
34 | # PyInstaller
35 | # Usually these files are written by a python script from a template
36 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
37 | *.manifest
38 | *.spec
39 |
40 | # Installer logs
41 | pip-log.txt
42 | pip-delete-this-directory.txt
43 |
44 | # Unit test / coverage reports
45 | htmlcov/
46 | .tox/
47 | .nox/
48 | .coverage
49 | .coverage.*
50 | .cache
51 | nosetests.xml
52 | coverage.xml
53 | *.cover
54 | *.py,cover
55 | .hypothesis/
56 | .pytest_cache/
57 | cover/
58 |
59 | # Translations
60 | *.mo
61 | *.pot
62 |
63 | # Django stuff:
64 | *.log
65 | local_settings.py
66 | db.sqlite3
67 | db.sqlite3-journal
68 |
69 | # Flask stuff:
70 | instance/
71 | .webassets-cache
72 |
73 | # Scrapy stuff:
74 | .scrapy
75 |
76 | # Sphinx documentation
77 | docs/_build/
78 |
79 | # PyBuilder
80 | .pybuilder/
81 | target/
82 |
83 | # Jupyter Notebook
84 | .ipynb_checkpoints
85 |
86 | # IPython
87 | profile_default/
88 | ipython_config.py
89 |
90 | # pyenv
91 | # For a library or package, you might want to ignore these files since the code is
92 | # intended to run in multiple environments; otherwise, check them in:
93 | # .python-version
94 |
95 | # pipenv
96 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
97 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
98 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
99 | # install all needed dependencies.
100 | #Pipfile.lock
101 |
102 | # poetry
103 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
104 | # This is especially recommended for binary packages to ensure reproducibility, and is more
105 | # commonly ignored for libraries.
106 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
107 | #poetry.lock
108 |
109 | # pdm
110 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
111 | #pdm.lock
112 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
113 | # in version control.
114 | # https://pdm.fming.dev/#use-with-ide
115 | .pdm.toml
116 |
117 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
118 | __pypackages__/
119 |
120 | # Celery stuff
121 | celerybeat-schedule
122 | celerybeat.pid
123 |
124 | # SageMath parsed files
125 | *.sage.py
126 |
127 | # Environments
128 | .env
129 | .venv
130 | env/
131 | venv/
132 | ENV/
133 | env.bak/
134 | venv.bak/
135 |
136 | # Spyder project settings
137 | .spyderproject
138 | .spyproject
139 |
140 | # Rope project settings
141 | .ropeproject
142 |
143 | # mkdocs documentation
144 | /site
145 |
146 | # mypy
147 | .mypy_cache/
148 | .dmypy.json
149 | dmypy.json
150 |
151 | # Pyre type checker
152 | .pyre/
153 |
154 | # pytype static type analyzer
155 | .pytype/
156 |
157 | # Cython debug symbols
158 | cython_debug/
159 |
160 | # PyCharm
161 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
162 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
163 | # and can be added to the global gitignore or merged into this file. For a more nuclear
164 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
165 | #.idea/
166 |
167 | # venv
168 | switch
169 |
170 | # plots
171 | plots
172 | *.png
173 |
174 | # data
175 | data
176 |
177 | # Pytorch bins
178 | *.pt
179 |
180 | # Pickle files
181 | *.pkl
182 | !auto_interp/*.pkl
183 |
184 | # Generated shell files
185 | auto_interp/*.sh
186 |
187 | other_dictionaries
188 |
189 | table/results.csv
--------------------------------------------------------------------------------
/.gitmodules:
--------------------------------------------------------------------------------
1 | [submodule "sae-auto-interp"]
2 | path = sae-auto-interp
3 | url = https://github.com/EleutherAI/sae-auto-interp/
4 |
--------------------------------------------------------------------------------
/.vscode/settings.json:
--------------------------------------------------------------------------------
1 | {
2 | "python.analysis.extraPaths": [
3 | "./sae-auto-interp"
4 | ]
5 | }
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 | Efficient Dictionary Learning with Switch Sparse Autoencoders (SAEs)
3 |
4 |
5 | More soon!
6 |
7 | ## Credits
8 | This repository is adapted from [dictionary_learning](https://github.com/saprmarks/dictionary_learning) by Samuel Marks and Aaron Mueller.
9 |
--------------------------------------------------------------------------------
/auto_interp/compare_dictionaries.py:
--------------------------------------------------------------------------------
1 | # %%
2 |
3 | import pickle
4 | import matplotlib.pyplot as plt
5 | import numpy as np
6 | import os
7 | from matplotlib.ticker import FixedLocator, FixedFormatter
8 |
# Plot colors keyed by number of experts (1 = the TopK baseline).
flop_matched_color_dict = {
    1: "#9467bd",
    2: "#ff7f0e",
    4: "#ff6600",
    8: "#ff3300",
}

fixed_width_color_dict = {
    1: "#9467bd",
    16: "#ff7f0e",
    32: "#ff6600",
    64: "#ff4d00",
    128: "#ff3300",
}

# Dictionary paths whose auto-interp results are compared in this script.
dictionaries = [
    "dictionaries/topk/k64",
    "dictionaries/fixed-width/16_experts/k64",
    "dictionaries/fixed-width/32_experts/k64",
    "dictionaries/fixed-width/64_experts/k64",
    "dictionaries/fixed-width/128_experts/k64",
    "dictionaries/flop-matched/2_experts/k64",
    "dictionaries/flop-matched/4_experts/k64",
    "dictionaries/flop-matched/8_experts/k64",
]

# Presumably one entry per path above, in the same order (the two are zipped
# when plotting). Use a context manager so the file handle is closed.
with open("results.pkl", "rb") as results_file:
    results = pickle.load(results_file)
35 |
def confidence_interval(successes, total, z=1.96):
    """Return the half-width of a normal-approximation (Wald) binomial interval.

    Args:
        successes: number of successful trials.
        total: number of trials.
        z: critical value; the default 1.96 gives a ~95% two-sided interval.

    Returns:
        ``z * sqrt(p * (1 - p) / total)`` with ``p = successes / total``, or
        ``nan`` when ``total`` is 0 (the interval is undefined, and dividing
        would raise ZeroDivisionError).
    """
    if total == 0:
        return float("nan")
    p = successes / total
    return z * np.sqrt(p * (1 - p) / total)


def _safe_ratio(numerator, denominator):
    """Return numerator / denominator, or nan when the denominator is 0."""
    return numerator / denominator if denominator else float("nan")


def process_results(results, dictionary):
    """Aggregate one dictionary's recall results into plottable series.

    Args:
        results: sequence whose element 0 is a list of
            ``(quantile_positives, quantile_totals)`` pairs (length-11 lists
            indexed by quantile) and whose element 1 is a list of
            ``(negative_positives, negative_totals)`` pairs.
        dictionary: passed through unchanged so callers can label the series.

    Returns:
        ``(x, y, ci, dictionary)``:

        * ``x`` is ``["Not", "Q1", ..., "Q10"]``.
        * ``y[0]`` is ``1 -`` the positive-prediction rate on non-activating
          examples.
        * ``y[1:]`` are per-quantile positive rates in reversed index order:
          "Q1" is quantile index 10 and "Q10" is index 1. Index 0 is not
          plotted.
        * ``ci`` holds the matching confidence-interval half-widths.

        Quantiles with no samples yield ``nan`` instead of raising.
    """
    n_slots = 11
    total_per_quantile = [0] * n_slots
    total_correct_per_quantile = [0] * n_slots
    for quantile_positives, quantile_totals in results[0]:
        for i in range(n_slots):
            total_per_quantile[i] += quantile_totals[i]
            total_correct_per_quantile[i] += quantile_positives[i]

    total_negative = 0
    total_negative_correct = 0
    for negative_positives, negative_totals in results[1]:
        total_negative += negative_totals
        total_negative_correct += negative_positives

    # Quantile indices in plotting order: 10, 9, ..., 1.
    plotted = range(n_slots - 1, 0, -1)
    average_per_quantile = [
        _safe_ratio(total_correct_per_quantile[i], total_per_quantile[i]) for i in plotted
    ]
    ci_per_quantile = [
        confidence_interval(total_correct_per_quantile[i], total_per_quantile[i]) for i in plotted
    ]

    average_negative = _safe_ratio(total_negative_correct, total_negative)
    ci_negative = confidence_interval(total_negative_correct, total_negative)

    x = ["Not"] + ["Q" + str(i) for i in range(1, 11)]
    y = [1 - average_negative] + average_per_quantile
    ci = [ci_negative] + ci_per_quantile

    return x, y, ci, dictionary
71 |
# Create two subplots with the new figure size
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(5.5, 2.5))

# Marker / error-bar style shared by every series.
errorbar_style = dict(fmt='-o', capsize=5, markersize=5)


def _legend_label(dict_name):
    """Return the legend label for a dictionary path.

    TopK baselines are labelled "Topk". Switch SAEs are labelled
    "Switch: <n>e" from their "/<n>_experts" path component. Any other path
    falls back to the path itself rather than silently reusing the previous
    series' label.
    """
    if "topk" in dict_name:
        return "Topk"
    for n in [2, 4, 8, 16, 32, 64, 128]:
        if f"/{n}_experts" in dict_name:
            return f"Switch: {n}e"
    return dict_name


# Pair each dictionary's results with its path (zip stops at the shorter of
# the two). The loop variable is not named `results`, so the full results
# list is not clobbered by the last entry.
for dictionary_results, dictionary in zip(results, dictionaries):
    x, y, ci, dict_name = process_results(dictionary_results, dictionary)
    label_name = _legend_label(dict_name)

    if "topk" in dict_name:
        # The TopK baseline is drawn on both panels as the reference curve.
        color = flop_matched_color_dict[1]  # Use color for 1 expert
        target_axes = (ax1, ax2)
    else:
        expert_count = int(dict_name.split("/")[2].split("_")[0])
        if "flop-matched" in dict_name:
            target_axes, color_dict = (ax1,), flop_matched_color_dict
        else:
            target_axes, color_dict = (ax2,), fixed_width_color_dict
        color = color_dict.get(expert_count, "#000000")  # Default to black if not found

    for ax in target_axes:
        ax.errorbar(x, y, yerr=ci, label=label_name, color=color, **errorbar_style)

for ax in (ax1, ax2):
    ax.set_xlabel("Quantiles", fontsize=8)
    ax.set_ylabel('Accuracy', fontsize=8)
    ax.grid(True, which="both", ls="--", linewidth=0.5)
    ax.legend(loc='upper center', prop={'size': 5.5})

    # Set y-axis limits and ticks
    ax.set_ylim(0.2, 1.0)  # Adjust these values as needed

    # Adjust tick label size
    ax.tick_params(axis='both', labelsize=6.5)

    # Remove minor ticks
    ax.minorticks_off()

    # Tilt all x-axis labels
    for tick in ax.get_xticklabels():
        tick.set_rotation(45)
        tick.set_va('top')


# Both panels use the same y limits, so drop ax2's y ticks and label.
ax2.yaxis.set_ticks_position('none')
ax2.yaxis.set_tick_params(size=0)
ax2.yaxis.set_ticklabels([])
ax2.set_ylabel('')

ax1.set_title("FLOP-Matched", fontsize=9)
ax2.set_title("Width-Matched", fontsize=9)

os.makedirs("plots", exist_ok=True)
plt.savefig("plots/detection_split.pdf", bbox_inches='tight')
plt.show()
# %%
139 |
--------------------------------------------------------------------------------
/auto_interp/generate_scripts.py:
--------------------------------------------------------------------------------
1 | # %%
2 | import glob
3 |
4 | import os
5 |
# Change to the repository root (the parent of this script's directory) so the
# glob below finds the top-level "dictionaries/" folder.
file_path = os.path.abspath(__file__)
os.chdir(os.path.dirname(os.path.dirname(file_path)))

# Every directory containing a saved dictionary (*.pt). Sorted so the generated
# scripts are identical from run to run (set order varies between processes).
# Separators are normalised to "/" so the output is the same on every OS.
dictionaries = glob.glob("dictionaries/**/*.pt", recursive=True)
parent_dirs = sorted({os.path.dirname(d).replace(os.sep, "/") for d in dictionaries})

# Change back so the scripts are written next to this file.
os.chdir(os.path.dirname(file_path))


def _write_eval_script(filename, to_do, only_k64=False):
    """Write one `sae_feature_evals.py` invocation per dictionary directory.

    When only_k64 is set, keep only directories whose path contains "k64".
    This does not filter by SAE family: topk, fixed-width and flop-matched
    dictionaries are all kept.
    """
    with open(filename, "w") as f:
        for parent_dir in parent_dirs:
            if only_k64 and "k64" not in parent_dir:
                continue
            f.write(f"python sae_feature_evals.py --sae_path {parent_dir} --to_do {to_do}\n")


_write_eval_script("run_all_feature_eval_generate.sh", "generate")
_write_eval_script("run_all_feature_eval_eval.sh", "eval")
_write_eval_script("run_all_feature_eval_generate_small.sh", "generate", only_k64=True)
_write_eval_script("run_all_feature_eval_eval_small.sh", "eval", only_k64=True)
--------------------------------------------------------------------------------
/auto_interp/results.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amudide/switch_sae/1b484570c74063b8331cb658a6bbf1964048a7de/auto_interp/results.pkl
--------------------------------------------------------------------------------
/auto_interp/sae_feature_evals.py:
--------------------------------------------------------------------------------
1 | # %%
2 | import os
3 | import sys
4 | sys.path.insert(1, os.path.join(sys.path[0], '..'))
5 |
6 | from functools import partial
7 | from nnsight import LanguageModel
8 | import asyncio
9 |
10 | from dictionary_learning.trainers.switch import SwitchAutoEncoder
11 | from dictionary_learning.trainers.top_k import AutoEncoderTopK
12 |
13 | import torch
14 | import einops
15 | import os
16 |
17 | from functools import partial
18 |
19 | import torch
20 | from nnsight import LanguageModel
21 |
22 | from sae_auto_interp.autoencoders.wrapper import AutoencoderLatents
23 | from sae_auto_interp.autoencoders.OpenAI import Autoencoder
24 |
25 |
26 | from sae_auto_interp.features import FeatureCache
27 | from sae_auto_interp.features.features import FeatureRecord
28 | from sae_auto_interp.utils import load_tokenized_data
29 |
30 |
31 | from sae_auto_interp.features import FeatureDataset, pool_max_activation_windows, random_activation_windows, sample
32 | from sae_auto_interp.config import FeatureConfig, ExperimentConfig
33 |
34 |
35 | from sae_auto_interp.explainers import SimpleExplainer
36 | from sae_auto_interp.scorers import RecallScorer
37 | from tqdm import tqdm
38 | import pickle
39 |
40 |
41 | from sae_auto_interp.clients import OpenRouter, Local
42 | from sae_auto_interp.utils import display
43 | import argparse
44 |
45 | # %%
46 |
# Run from the repository root (the parent of this file's directory) so that
# relative paths such as "dictionaries/..." resolve.
this_dir = os.path.dirname(os.path.abspath(__file__))
one_level_up = os.path.dirname(this_dir)
os.chdir(one_level_up)

# %%

# Evaluation settings.
CTX_LEN, BATCH_SIZE = 128, 32
N_TOKENS = 10_000_000
N_SPLITS = 2
NUM_FEATURES_TO_TEST = 1000

device = "cuda:1"

# Fix the torch RNG so the random feature subset drawn later is reproducible.
torch.manual_seed(0)
64 |
65 | # %%
66 |
def _detect_notebook():
    """Return True when running under IPython/Jupyter, enabling autoreload there.

    Replaces a bare ``except:`` that hid every error. A failure while enabling
    autoreload now only prints a warning, instead of silently switching the
    script to command-line mode (where argparse would then fail).
    """
    try:
        from IPython import get_ipython  # type: ignore
    except ImportError:
        return False

    ipython = get_ipython()
    if ipython is None:  # IPython is installed, but this is a plain `python` run
        return False

    try:
        ipython.run_line_magic("load_ext", "autoreload")
        ipython.run_line_magic("autoreload", "2")
    except Exception as err:  # autoreload is a convenience only
        print(f"Could not enable autoreload: {err}")
    return True


is_notebook = _detect_notebook()

if not is_notebook:
    argparser = argparse.ArgumentParser()
    argparser.add_argument("--sae_path", type=str, required=True)
    argparser.add_argument("--to_do", type=str, required=True, choices=["generate", "eval", "both"])

    args = argparser.parse_args()
    SAE_PATH = args.sae_path
    to_do = args.to_do
else:
    SAE_PATH = "dictionaries/topk/k64"
    to_do = "both"
    # SAE_PATH = "dictionaries/fixed-width/16_experts/k64"

# Drop a trailing "/" (e.g. "dictionaries/topk/k64/"). Otherwise the k and
# expert-count parsing from the last path components breaks, and paths below
# get doubled slashes.
SAE_PATH = SAE_PATH.rstrip("/")

RAW_ACTIVATIONS_DIR = f"/media/jengels/sda/switch/{SAE_PATH}"
SAVE_FILE = f"/media/jengels/sda/switch/{SAE_PATH}/results.pkl"
FINAL_SAVE_FILE = f"/media/jengels/sda/switch/{SAE_PATH}/final_results.pkl"
95 |
96 | # %%
97 |
98 |
# Parse SAE hyperparameters from the directory name:
#   dictionaries/topk/k<k>                  -> TopK SAE
#   dictionaries/<group>/<n>_experts/k<k>   -> Switch SAE with n experts
path_parts = SAE_PATH.split("/")
if "topk" not in SAE_PATH:
    num_experts = int(path_parts[-2].split("_")[0])
    k = int(path_parts[-1][1:])
    ae = SwitchAutoEncoder.from_pretrained(f"{SAE_PATH}/ae.pt", k=k, experts=num_experts, device=device)
else:
    k = int(path_parts[-1][1:])
    ae = AutoEncoderTopK.from_pretrained(f"{SAE_PATH}/ae.pt", k=k, device=device)

ae.to(device)

model = LanguageModel("openai-community/gpt2", device_map=device, dispatch=True)

# TODO: Ideally use openwebtext
tokens = load_tokenized_data(CTX_LEN, model.tokenizer, "kh4dien/fineweb-100m-sample", "train[:15%]")

# %%

# Sample NUM_FEATURES_TO_TEST distinct latent indices uniformly at random.
WIDTH = ae.dict_size
random_features = torch.randperm(WIDTH)[:NUM_FEATURES_TO_TEST]
125 |
126 | # %%
127 |
generate = to_do in ["generate", "both"]
if generate:

    def _forward(ae, x):
        """Run the SAE on x and return dense latents shaped (batch, ctx, WIDTH).

        ``ae.forward(..., output_features="all")`` is unpacked as
        ``(_, _, top_acts, top_indices)``: the active latent values and their
        indices. These are presumably flattened over batch*ctx, which the
        rearrange below assumes.
        """
        _, _, top_acts, top_indices = ae.forward(x, output_features="all")

        # Build the zero buffer from top_acts so it shares its dtype and device.
        # A fixed float32 buffer makes scatter_ fail when the SAE runs in
        # another dtype.
        expanded = top_acts.new_zeros(top_acts.shape[0], WIDTH)
        expanded.scatter_(1, top_indices, top_acts)

        return einops.rearrange(expanded, "(b c) w -> b c w", b=x.shape[0], c=x.shape[1])

    # We can simply add the new module as an attribute to an existing
    # submodule on GPT-2's module tree.
    submodule = model.transformer.h[8]
    submodule.ae = AutoencoderLatents(
        ae,
        partial(_forward, ae),
        width=ae.dict_size
    )

    # Hook the SAE onto block 8's output for all subsequent traces.
    with model.edit(" ", inplace=True):
        acts = submodule.output[0]
        submodule.ae(acts, hook=True)

    # Appears to be a sanity check that the hook runs; `latents` is not used afterwards.
    with model.trace("hello, my name is"):
        latents = submodule.ae.output.save()

    module_path = submodule.path

    submodule_dict = {module_path: submodule}
    # Only cache activations for the randomly sampled features.
    module_filter = {module_path: random_features.to(device)}

    cache = FeatureCache(
        model,
        submodule_dict,
        batch_size=BATCH_SIZE,
        filters=module_filter
    )

    cache.run(N_TOKENS, tokens)

    cache.save_splits(
        n_splits=N_SPLITS,
        save_dir=RAW_ACTIVATIONS_DIR,
    )
174 |
175 | # %%
176 |
177 |
cfg = FeatureConfig(
    width=WIDTH,
    min_examples=200,
    max_examples=2_000,
    example_ctx_len=CTX_LEN,
    # Must match the number of splits the activations were saved with.
    n_splits=N_SPLITS,
)

sample_cfg = ExperimentConfig(n_random=50)

# This is a hack because this isn't currently defined in the repo
sample_cfg.chosen_quantile = 0

# %%

# Generation-only runs stop here, before loading the saved activations back in.
# sys.exit is used because the bare `exit()` helper is only injected by the
# `site` module and is not guaranteed to exist.
if to_do == "generate":
    sys.exit()

dataset = FeatureDataset(
    raw_dir=RAW_ACTIVATIONS_DIR,
    cfg=cfg,
)
199 |
200 | # %%
201 |
# Define these functions here so we don't need to edit the functions in the git submodule
def default_constructor(
    record,
    tokens,
    buffer_output,
    n_random: int,
    cfg: FeatureConfig
):
    """Fill `record` with max-activation windows and `n_random` random windows.

    Both sae_auto_interp helpers receive the same `tokens` and `buffer_output`.
    The random windows use `cfg.example_ctx_len` as their context length.
    """
    shared = {"tokens": tokens, "buffer_output": buffer_output}
    pool_max_activation_windows(record, cfg=cfg, **shared)
    random_activation_windows(record, n_random=n_random, ctx_len=cfg.example_ctx_len, **shared)


# Bind everything except the per-feature arguments (record, buffer_output).
constructor = partial(default_constructor, n_random=sample_cfg.n_random, tokens=tokens, cfg=cfg)
sampler = partial(sample, cfg=sample_cfg)
237 |
238 |
def load(
    dataset,
    constructor,
    sampler,
    transform = None
):
    """Lazily yield a FeatureRecord for every non-None item in `dataset.buffers`.

    Each record is built from the item's `.feature`. It is then passed through
    `constructor(record=..., buffer_output=...)`, `sampler(record)` and
    `transform(record)` in that order; any stage that is None is skipped.
    """
    post_steps = [step for step in (sampler, transform) if step is not None]

    def _build(buffer_output):
        record = FeatureRecord(buffer_output.feature)
        if constructor is not None:
            constructor(record=record, buffer_output=buffer_output)
        for step in post_steps:
            step(record)
        return record

    for buffer in dataset.buffers:
        for item in buffer:
            if item is None:
                continue
            yield _build(item)
262 | # %%
263 |
record_iterator = load(dataset=dataset, constructor=constructor, sampler=sampler, transform=None)

# To inspect a record interactively (note: this consumes it from the iterator):
#   display(next(record_iterator), model.tokenizer, n=5)
# %%

# Explainer and scorer both query a locally served model started with:
#   vllm serve hugging-quants/Meta-Llama-3.1-70B-Instruct-AWQ-INT4 --max_model_len 10000 --tensor-parallel-size 2
client = Local("hugging-quants/Meta-Llama-3.1-70B-Instruct-AWQ-INT4")
273 |
274 | # %%
275 |
def tally_recall_scores(score_instances):
    """Count recall-scorer predictions per activation quantile.

    Each instance exposes ``distance`` and ``prediction``:

    * ``distance`` is the quantile index (0-10) of the example, or -1 for a
      non-activating example.
    * ``prediction`` equal to 1 counts as a positive. Instances with
      ``prediction == -1`` (presumably unparsable answers) are ignored.

    Returns:
        ``(quantile_positives, quantile_totals, negative_positives,
        negative_totals)``. The first two are length-11 lists.
    """
    quantile_positives = [0] * 11
    quantile_totals = [0] * 11
    negative_positives = 0
    negative_totals = 0
    for instance in score_instances:
        if instance.prediction == -1:
            continue
        hit = int(instance.prediction == 1)
        if instance.distance == -1:
            negative_totals += 1
            negative_positives += hit
        else:
            quantile_totals[instance.distance] += 1
            quantile_positives[instance.distance] += hit
    return quantile_positives, quantile_totals, negative_positives, negative_totals


async def run_async():
    """Explain and score every record from `record_iterator`.

    Pickles ``(positive_scores, negative_scores, explanations, feature_ids)``
    to SAVE_FILE after each record and to FINAL_SAVE_FILE at the end. Index i
    of all four lists always refers to the same feature.
    """
    positive_scores = []
    negative_scores = []
    explanations = []
    feature_ids = []
    total_positive_score = 0
    total_negative_score = 0
    total_evaluated = 0

    bar = tqdm(record_iterator, total=NUM_FEATURES_TO_TEST)

    for record in bar:
        explainer = SimpleExplainer(
            client,
            model.tokenizer,
            max_tokens=50,
            temperature=0.0
        )
        explainer_result = await explainer(record)
        record.explanation = explainer_result.explanation

        scorer = RecallScorer(
            client,
            model.tokenizer,
            max_tokens=25,
            temperature=0.0,
            batch_size=4,
        )
        score = await scorer(record)

        quantile_positives, quantile_totals, negative_positives, negative_totals = (
            tally_recall_scores(score.score)
        )

        # Append to all four lists together so they stay index-aligned. Before,
        # records skipped below were missing from explanations and feature_ids
        # but present in the score lists.
        positive_scores.append((quantile_positives, quantile_totals))
        negative_scores.append((negative_positives, negative_totals))
        explanations.append(record.explanation)
        feature_ids.append(record.feature.feature_index)

        # Rewrite the intermediate results file after every record.
        with open(SAVE_FILE, "wb") as f:
            pickle.dump((positive_scores, negative_scores, explanations, feature_ids), f)

        # Without valid predictions on both splits a recall can't be computed.
        if (sum(quantile_totals) == 0) or (negative_totals == 0):
            continue

        total_positive_score += sum(quantile_positives) / sum(quantile_totals)
        total_negative_score += negative_positives / negative_totals
        total_evaluated += 1

        bar.set_description(f"Positive Recall: {total_positive_score / total_evaluated}, Negative Recall: {total_negative_score / total_evaluated}")

        print(quantile_positives, quantile_totals)

    with open(FINAL_SAVE_FILE, "wb") as f:
        pickle.dump((positive_scores, negative_scores, explanations, feature_ids), f)
360 |
361 |
# asyncio.run() raises RuntimeError when an event loop is already running (as
# in a Jupyter kernel). In that case schedule the evaluation on the running
# loop; `await eval_task` in a later cell to wait for it. Otherwise run it to
# completion.
try:
    _running_loop = asyncio.get_running_loop()
except RuntimeError:
    _running_loop = None

if _running_loop is None:
    asyncio.run(run_async())
else:
    eval_task = _running_loop.create_task(run_async())
365 |
366 | # %%
367 |
--------------------------------------------------------------------------------
/compare_geometry/examine_duplicate_features.py:
--------------------------------------------------------------------------------
1 | # %%
2 | import os
3 | import sys
4 | sys.path.insert(1, os.path.join(sys.path[0], '..'))
5 |
6 | import pandas as pd
7 | import torch
8 | from dictionary_learning.trainers.switch import SwitchAutoEncoder
9 | from dictionary_learning.trainers.top_k import AutoEncoderTopK
10 | import einops
11 | from tqdm import tqdm
12 | from transformers import GPT2Tokenizer
13 |
# Inference only: no autograd graphs are needed.
torch.set_grad_enabled(False)

# %%

# Paths are relative; presumably this is run from compare_geometry/ — confirm.
ctx_len = 128  # tokens per context; used to map flattened rows back to (context, position)
save_location = "../data/top_activating_for_dupes_layer8.pt"
device = "cpu"  # or e.g. "cuda:0"

duplicate_features = pd.read_csv('../data/duplicates.csv')
data = torch.load("../data/gpt2_activations_layer8.pt", map_location=device)
tokens = torch.load("../data/gpt2_tokens.pt", map_location=device)
25 |
26 | # %%
27 |
batch_size = 256
store_topk_activating = 100  # top activations kept per duplicated feature

unique_num_experts = duplicate_features['num_experts'].unique()
unique_topks = duplicate_features['k'].unique()

# (num_experts, k) -> {feature_index: (top activations, num_dupes)}, where each
# activation is (context_index, token_index, value).
ae_details_to_top_activating = {}


for num_experts in unique_num_experts:
    for k in tqdm(unique_topks):
        # Plain Python ints/floats are stored throughout, rather than NumPy
        # scalars, so the saved file loads with torch.load's weights_only=True
        # default (PyTorch >= 2.6).
        key = (int(num_experts), int(k))

        filtered_df = duplicate_features[(duplicate_features['num_experts'] == num_experts) & (duplicate_features['k'] == k)]
        feature_index_to_num_dupes = filtered_df.set_index('feature_index')['num_dupes'].to_dict()
        feature_indices = [int(i) for i in feature_index_to_num_dupes]
        feature_index_tensor = torch.tensor(feature_indices, device=device)
        feature_index_to_topk_activating = {feature_index: [] for feature_index in feature_indices}

        if feature_index_tensor.numel() == 0:
            ae_details_to_top_activating[key] = {}
            continue

        if num_experts == 1:
            ae = AutoEncoderTopK.from_pretrained(f"../dictionaries/topk/k{k}/ae.pt", k=k, device=device)
        else:
            ae = SwitchAutoEncoder.from_pretrained(f"../dictionaries/fixed-width/{num_experts}_experts/k{k}/ae.pt", k=k, experts=num_experts, device=device)

        for batch_start in range(0, len(data), batch_size):
            batch = data[batch_start:batch_start + batch_size].to(device)
            batch_tokens = tokens[batch_start:batch_start + batch_size].to(device)

            top_acts, top_indices = ae.forward(batch, output_features="all")[2:]
            # Dense activations, one row per (context, position) pair:
            # row r <-> context r // ctx_len, position r % ctx_len.
            activations = torch.zeros(batch_tokens.shape[0] * batch_tokens.shape[1], len(ae.decoder), device=device)
            activations.scatter_(1, top_indices, top_acts)
            dupe_feature_activations = activations[..., feature_index_tensor]
            top_activating = torch.topk(dupe_feature_activations, store_topk_activating, dim=0)

            # Copy to the host once per batch instead of three times per feature.
            top_rows = top_activating.indices.cpu().numpy()
            top_values = top_activating.values.cpu().numpy()
            top_contexts = top_rows // ctx_len + batch_start
            top_token_ids = top_rows % ctx_len
            # Column j of the top-k result belongs to feature_indices[j].
            for j, feature_index in enumerate(feature_indices):
                feature_index_to_topk_activating[feature_index].extend(
                    zip(top_contexts[:, j].tolist(), top_token_ids[:, j].tolist(), top_values[:, j].tolist())
                )

        # Keep only the overall top activations per feature across all batches.
        for feature_index, candidates in feature_index_to_topk_activating.items():
            num_dupes = int(feature_index_to_num_dupes[feature_index])
            candidates.sort(key=lambda x: x[2], reverse=True)
            feature_index_to_topk_activating[feature_index] = (candidates[:store_topk_activating], num_dupes)

        ae_details_to_top_activating[key] = feature_index_to_topk_activating

torch.save(ae_details_to_top_activating, save_location)
81 |
82 | # %%
83 |
# Load the top activating examples for duplicated features.
# weights_only=False: files written by older versions of the cell above hold
# NumPy scalars, which torch>=2.6's default weights_only=True loader rejects.
# The file is produced locally by this script, so full unpickling is acceptable.
ae_details_to_top_activating = torch.load(save_location, map_location=device, weights_only=False)

context_limit = 15  # number of preceding tokens shown before each activating token

# Load GPT2 tokenizer
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

for (num_experts, k), feature_index_to_topk_activating in tqdm(ae_details_to_top_activating.items()):
    output_file = f"../data/top_activating_for_dupes_layer8_num-experts{num_experts}_topk{k}.txt"

    if num_experts == 1:
        ae = AutoEncoderTopK.from_pretrained(f"../dictionaries/topk/k{k}/ae.pt", k=k, device=device)
    else:
        ae = SwitchAutoEncoder.from_pretrained(f"../dictionaries/fixed-width/{num_experts}_experts/k{k}/ae.pt", k=k, experts=num_experts, device=device)

    # Pairwise cosine similarity between decoder rows, with self-similarity zeroed.
    decoder_vecs = ae.decoder
    decoder_vecs = decoder_vecs / decoder_vecs.norm(dim=-1, keepdim=True)
    sims = decoder_vecs @ decoder_vecs.T
    threshold = 0.9
    sims.fill_diagonal_(0)

    # Cosine similarity between encoder rows, reported for the same pairs.
    encoder_vecs = ae.encoder.weight
    encoder_vecs = encoder_vecs / encoder_vecs.norm(dim=-1, keepdim=True)
    encoder_sims = encoder_vecs @ encoder_vecs.T

    with open(output_file, 'w', encoding="utf-8") as f:
        f.write(f"--------------- NUM EXPERTS = {num_experts}, K = {k} ---------------\n\n\n")
        for feature_index, (top_activating, num_dupes) in feature_index_to_topk_activating.items():
            similar = torch.nonzero(sims[feature_index] > threshold).flatten().cpu().numpy()
            nonzero_sims = sims[feature_index][similar].cpu().numpy()
            this_encoder_sims = encoder_sims[feature_index][similar].cpu().numpy()
            f.write(f"feature {feature_index} with {num_dupes} dupes: {similar}, {nonzero_sims}, {this_encoder_sims}\n")
            for context_index, token_index, value in top_activating:
                context_start = max(0, token_index - context_limit)
                context_end = min(token_index, ctx_len)
                # Decode the preceding tokens together, so characters whose bytes
                # span several BPE tokens render correctly.
                token_context = tokens[context_index][context_start:context_end]
                context = tokenizer.decode(token_context.tolist())
                context = context.replace("\n", " ")
                # The activating token itself sits at position context_end.
                # Any position < ctx_len is valid; the old `context_end + 1`
                # check dropped tokens in the last position.
                if context_end < ctx_len:
                    active_token = tokenizer.decode([int(tokens[context_index][context_end])])
                    active_token = active_token.replace("\n", "")
                else:
                    active_token = ""
                f.write(f"{value:.4f} {active_token}\n{context}\n")
            f.write("\n")
        f.write("\n")
133 | # %%
134 |
--------------------------------------------------------------------------------
/compare_geometry/find_duplicate_features.py:
--------------------------------------------------------------------------------
1 | # %%
2 |
3 | import os
4 | import sys
5 | sys.path.insert(1, os.path.join(sys.path[0], '..'))
6 |
7 | from dictionary_learning.trainers.switch import SwitchAutoEncoder
8 | from dictionary_learning.trainers.top_k import AutoEncoderTopK
9 | import einops
10 | import matplotlib.pyplot as plt
11 | import torch
12 | import numpy as np
13 | from tqdm import tqdm
14 | import os
15 | import pandas as pd
16 |
17 |
os.makedirs("plots/compare_geometry", exist_ok=True)

torch.set_grad_enabled(False)

# Grid of SAEs to compare: number of experts (1 == plain TopK SAE) x k.
experts = [1, 16, 32, 64, 128]
ks = [8, 16, 32, 48, 64, 96, 128, 192]

# Prefer the second GPU as before, but fall back instead of crashing on
# machines with fewer GPUs.
if torch.cuda.device_count() > 1:
    device = "cuda:1"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Two decoder directions count as duplicates when their cosine similarity
# exceeds this value.
threshold = 0.9

fig, axs = plt.subplots(len(experts), len(ks), figsize=(20, 10))

# Rows of (num_experts, k, feature_index, num_dupes), only for features that
# have at least one duplicate.
data = []

for row, num_experts in enumerate(experts):
    for col, k in enumerate(tqdm(ks)):
        ax = axs[row, col]

        if num_experts == 1:
            ae = AutoEncoderTopK.from_pretrained(f"../dictionaries/topk/k{k}/ae.pt", k=k, device=device)
        else:
            ae = SwitchAutoEncoder.from_pretrained(f"../dictionaries/fixed-width/{num_experts}_experts/k{k}/ae.pt", k=k, experts=num_experts, device=device)

        # Pairwise cosine similarity of decoder rows; zero the diagonal so a
        # feature is never counted as its own duplicate.
        normalized_weights = ae.decoder.data / ae.decoder.data.norm(dim=-1, keepdim=True)
        sims = normalized_weights @ normalized_weights.T
        sims.fill_diagonal_(0)

        dupe_counts = (sims > threshold).sum(dim=-1).cpu().numpy()
        # Indices of features with at least one duplicate.
        has_dupes = np.flatnonzero(dupe_counts)

        ax.hist(dupe_counts[has_dupes], bins=100)
        ax.set_title(f"{num_experts} experts, k={k}")

        data.extend(
            (num_experts, k, int(feature_index), int(dupe_counts[feature_index]))
            for feature_index in has_dupes
        )

fig.suptitle(f"Number of duplicates per feature, threshold={threshold}")
fig.supxlabel("Number of duplicates")
fig.supylabel("Number of features")
fig.tight_layout()
fig.savefig("plots/compare_geometry/num_duplicates_per_feature_fixed_width.pdf")

df = pd.DataFrame(data, columns=["num_experts", "k", "feature_index", "num_dupes"])

# Sort by num_experts, then k, then num_dupes in descending order
df = df.sort_values(by=["num_experts", "k", "num_dupes"], ascending=[True, True, False])

os.makedirs("../data", exist_ok=True)
df.to_csv("../data/duplicates.csv", index=False)
82 |
--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
# Model under study (Hugging Face id), the layer whose activations are used,
# and the activation width (presumably the model's hidden size — confirm).
lm = 'openai-community/gpt2'
layer = 8
activation_dim = 768

# Hugging Face dataset supplying the text.
hf = 'Skylion007/openwebtext'

# Run-length settings (presumably training steps and buffered contexts).
steps = 10 ** 5
n_ctxs = 10_000
7 |
--------------------------------------------------------------------------------
/dictionary_learning/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amudide/switch_sae/1b484570c74063b8331cb658a6bbf1964048a7de/dictionary_learning/.DS_Store
--------------------------------------------------------------------------------
/dictionary_learning/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 |
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/#use-with-ide
110 | .pdm.toml
111 |
112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113 | __pypackages__/
114 |
115 | # Celery stuff
116 | celerybeat-schedule
117 | celerybeat.pid
118 |
119 | # SageMath parsed files
120 | *.sage.py
121 |
122 | # Environments
123 | .env
124 | .venv
125 | env/
126 | venv/
127 | ENV/
128 | env.bak/
129 | venv.bak/
130 |
131 | # Spyder project settings
132 | .spyderproject
133 | .spyproject
134 |
135 | # Rope project settings
136 | .ropeproject
137 |
138 | # mkdocs documentation
139 | /site
140 |
141 | # mypy
142 | .mypy_cache/
143 | .dmypy.json
144 | dmypy.json
145 |
146 | # Pyre type checker
147 | .pyre/
148 |
149 | # pytype static type analyzer
150 | .pytype/
151 |
152 | # Cython debug symbols
153 | cython_debug/
154 |
155 | # PyCharm
156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158 | # and can be added to the global gitignore or merged into this file. For a more nuclear
159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160 | #.idea/
161 | dictionaries
162 | wandb
163 | experiment*
164 | run_experiment.sh
165 | nohup.out
--------------------------------------------------------------------------------
/dictionary_learning/__init__.py:
--------------------------------------------------------------------------------
1 | from .dictionary import AutoEncoder, GatedAutoEncoder, JumpReluAutoEncoder
2 | from .buffer import ActivationBuffer
--------------------------------------------------------------------------------
/dictionary_learning/config.py:
--------------------------------------------------------------------------------
DEBUG = False  # global debugging switch; other modules import and check this flag
--------------------------------------------------------------------------------
/dictionary_learning/evaluation.py:
--------------------------------------------------------------------------------
1 | """
2 | Utilities for evaluating dictionaries on a model and dataset.
3 | """
4 |
5 | import torch as t
6 | from .buffer import ActivationBuffer, NNsightActivationBuffer
7 | from nnsight import LanguageModel
8 | from .config import DEBUG
9 |
10 |
def loss_recovered(
    text,  # a batch of text
    model: LanguageModel,  # an nnsight LanguageModel
    submodule,  # submodules of model
    dictionary,  # dictionaries for submodules
    max_len=None,  # max context length for loss recovered
    normalize_batch=False,  # normalize batch before passing through dictionary
    io="out",  # can be 'in', 'out', or 'in_and_out'
    tracer_args = {'use_cache': False, 'output_attentions': False},  # minimize cache during model trace.
):
    """
    How much of the model's loss is recovered by replacing the component output
    with the reconstruction by the autoencoder?

    The model is run on ``text`` three times: unmodified, with the selected
    activations replaced by ``dictionary``'s reconstruction, and with them
    replaced by zeros. ``tracer_args`` is not applied to the unmodified run.

    Returns:
        ``(loss_original, loss_reconstructed, loss_zero)``: next-token
        cross-entropy losses (scalar tensors). Pad tokens are ignored when the
        model's tokenizer defines a pad token id.

    Raises:
        ValueError: if ``io`` is not 'in', 'out' or 'in_and_out'.
    """
    # Validate before any forward pass so a typo fails fast instead of after a
    # full (possibly expensive) model run.
    if io not in ('in', 'out', 'in_and_out'):
        raise ValueError(f"Invalid value for io: {io}")

    if max_len is None:
        invoker_args = {}
    else:
        invoker_args = {"truncation": True, "max_length": max_len}

    # unmodified logits
    with model.trace(text, invoker_args=invoker_args):
        logits_original = model.output.save()
    logits_original = logits_original.value

    # capture the activations the dictionary will reconstruct
    with model.trace(text, **tracer_args, invoker_args=invoker_args):
        if io == 'in':
            x = submodule.input[0]
            if type(submodule.input.shape) == tuple: x = x[0]
            if normalize_batch:
                scale = (dictionary.activation_dim ** 0.5) / x.norm(dim=-1).mean()
                x = x * scale
        elif io == 'out':
            x = submodule.output
            if type(submodule.output.shape) == tuple: x = x[0]
            if normalize_batch:
                scale = (dictionary.activation_dim ** 0.5) / x.norm(dim=-1).mean()
                x = x * scale
        else:  # 'in_and_out': reconstruct the input, write it to the output
            x = submodule.input[0]
            if type(submodule.input.shape) == tuple: x = x[0]
            if DEBUG:
                print(f'x.shape: {x.shape}')
            if normalize_batch:
                scale = (dictionary.activation_dim ** 0.5) / x.norm(dim=-1).mean()
                x = x * scale
        x = x.save()

    # pull this out so dictionary can be written without FakeTensor (top_k needs this)
    x_hat = dictionary(x.view(-1, x.shape[-1])).view(x.shape)

    # intervene with `x_hat`
    with model.trace(text, **tracer_args, invoker_args=invoker_args):
        if io == 'in':
            x = submodule.input[0]
            if normalize_batch:
                # undo the normalization applied before encoding
                scale = (dictionary.activation_dim ** 0.5) / x.norm(dim=-1).mean()
                x_hat = x_hat / scale
            if type(submodule.input.shape) == tuple:
                submodule.input[0][:] = x_hat
            else:
                submodule.input = x_hat
        elif io == 'out':
            x = submodule.output
            if normalize_batch:
                scale = (dictionary.activation_dim ** 0.5) / x.norm(dim=-1).mean()
                x_hat = x_hat / scale
            if type(submodule.output.shape) == tuple:
                submodule.output = (x_hat,)
            else:
                submodule.output = x_hat
        else:  # 'in_and_out'
            x = submodule.input[0]
            if normalize_batch:
                scale = (dictionary.activation_dim ** 0.5) / x.norm(dim=-1).mean()
                x_hat = x_hat / scale
            submodule.output = x_hat

        logits_reconstructed = model.output.save()
    logits_reconstructed = logits_reconstructed.value

    # logits when replacing component activations with zeros
    with model.trace(text, **tracer_args, invoker_args=invoker_args):
        if io == 'in':
            x = submodule.input[0]
            if type(submodule.input.shape) == tuple:
                submodule.input[0][:] = t.zeros_like(x[0])
            else:
                submodule.input = t.zeros_like(x)
        else:  # 'out' and 'in_and_out' both zero the output
            x = submodule.output
            if type(submodule.output.shape) == tuple:
                submodule.output[0][:] = t.zeros_like(x[0])
            else:
                submodule.output = t.zeros_like(x)

        model_input = model.input.save()
        logits_zero = model.output.save()
    logits_zero = logits_zero.value

    # get everything into the right format: HF model outputs wrap logits in
    # `.logits`, raw tensors do not
    try:
        logits_original = logits_original.logits
        logits_reconstructed = logits_reconstructed.logits
        logits_zero = logits_zero.logits
    except AttributeError:
        pass

    if isinstance(text, t.Tensor):
        tokens = text
    else:
        try:
            tokens = model_input[1]['input_ids']
        except Exception:
            tokens = model_input[1]['input']

    # compute losses; only ignore padding when a pad token id actually exists
    # (CrossEntropyLoss rejects ignore_index=None)
    tokenizer = getattr(model, 'tokenizer', None)
    pad_token_id = getattr(tokenizer, 'pad_token_id', None)
    loss_kwargs = {} if pad_token_id is None else {'ignore_index': pad_token_id}
    loss_fn = t.nn.CrossEntropyLoss(**loss_kwargs)

    losses = []
    for logits in [logits_original, logits_reconstructed, logits_zero]:
        # predict token i+1 from position i
        loss = loss_fn(
            logits[:, :-1, :].reshape(-1, logits.shape[-1]), tokens[:, 1:].reshape(-1)
        )
        losses.append(loss)

    return tuple(losses)
146 |
147 |
def evaluate(
    dictionary,  # a dictionary
    activations,  # a generator of activations; if an ActivationBuffer, also compute loss recovered
    max_len=128,  # max context length for loss recovered
    batch_size=128,  # batch size for loss recovered
    io="out",  # can be 'in', 'out', or 'in_and_out'
    normalize_batch=False,  # normalize batch before passing through dictionary
    tracer_args={'use_cache': False, 'output_attentions': False},  # minimize cache during model trace.
    device="cpu",
):
    """
    Compute reconstruction and sparsity statistics of ``dictionary`` on one
    batch drawn from ``activations``.

    Always returns a dict with ``l2_loss``, ``l1_loss``, ``mse_loss``, ``l0``,
    ``frac_alive``, ``frac_variance_explained``, ``cossim`` and ``l2_ratio``.
    When ``activations`` is an ActivationBuffer / NNsightActivationBuffer,
    ``loss_original``, ``loss_reconstructed``, ``loss_zero`` and
    ``frac_recovered`` are added via ``loss_recovered`` on a text batch.

    Raises:
        StopIteration: if ``activations`` is already exhausted.
    """
    with t.no_grad():

        out = {}  # dict of results

        try:
            x = next(activations).to(device)
        except StopIteration:
            raise StopIteration(
                "Not enough activations in buffer. Pass a buffer with a smaller batch size or more data."
            ) from None
        if normalize_batch:
            x = x / x.norm(dim=-1).mean() * (dictionary.activation_dim ** 0.5)

        x_hat, f = dictionary(x, output_features=True)
        l2_loss = t.linalg.norm(x - x_hat, dim=-1).mean()
        mse_loss = (x - x_hat).pow(2).sum(dim=-1).mean()
        l1_loss = f.norm(p=1, dim=-1).mean()
        l0 = (f != 0).float().sum(dim=-1).mean()
        # Collapse every leading dim so `any` runs per feature. Flattening only
        # dims 0-1 turned a 2-D (batch, dict_size) tensor into 1-D, and `any`
        # then reduced it to a single bool (frac_alive of 0 or 1/dict_size).
        frac_alive = f.reshape(-1, f.shape[-1]).any(dim=0).sum() / dictionary.dict_size

        # cosine similarity between x and x_hat
        x_normed = x / t.linalg.norm(x, dim=-1, keepdim=True)
        x_hat_normed = x_hat / t.linalg.norm(x_hat, dim=-1, keepdim=True)
        cossim = (x_normed * x_hat_normed).sum(dim=-1).mean()

        # l2 ratio
        l2_ratio = (t.linalg.norm(x_hat, dim=-1) / t.linalg.norm(x, dim=-1)).mean()

        # compute variance explained
        total_variance = t.var(x, dim=0).sum()
        residual_variance = t.var(x - x_hat, dim=0).sum()
        frac_variance_explained = (1 - residual_variance / total_variance)

        out["l2_loss"] = l2_loss.item()
        out["l1_loss"] = l1_loss.item()
        out["mse_loss"] = mse_loss.item()
        out["l0"] = l0.item()
        out["frac_alive"] = frac_alive.item()
        out["frac_variance_explained"] = frac_variance_explained.item()
        out["cossim"] = cossim.item()
        out["l2_ratio"] = l2_ratio.item()

        if not isinstance(activations, (ActivationBuffer, NNsightActivationBuffer)):
            return out

        # compute loss recovered
        loss_original, loss_reconstructed, loss_zero = loss_recovered(
            activations.text_batch(batch_size=batch_size),
            activations.model,
            activations.submodule,
            dictionary,
            max_len=max_len,
            normalize_batch=normalize_batch,
            io=io,
            tracer_args=tracer_args
        )
        frac_recovered = (loss_reconstructed - loss_zero) / (loss_original - loss_zero)

        out["loss_original"] = loss_original.item()
        out["loss_reconstructed"] = loss_reconstructed.item()
        out["loss_zero"] = loss_zero.item()
        out["frac_recovered"] = frac_recovered.item()

        return out
--------------------------------------------------------------------------------
/dictionary_learning/grad_pursuit.py:
--------------------------------------------------------------------------------
1 | """
2 | Implements batched gradient pursuit algorithm here:
3 | https://www.lesswrong.com/posts/C5KAZQib3bzzpeyrg/full-post-progress-update-1-from-the-gdm-mech-interp-team#Inference_Time_Optimisation:~:text=two%20seem%20promising.-,Details%20of%20Sparse%20Approximation%20Algorithms%20(for%20accelerators),-This%20section%20gets
4 | """
5 |
6 | import torch as t
7 |
8 |
def _grad_pursuit_update_step(signal, weights, dictionary, batch_arange, selected_features):
    """
    Run one step of batched gradient pursuit.

    signal: b x d, weights: b x n, dictionary: d x n, batch_arange: b, selected_features: b x n

    For every batch row, adds the atom with the largest inner product with the
    current residual to the active set (``selected_features`` is modified in
    place), then takes an exact line-search step along the residual gradient
    restricted to the active set and clips the weights to be non-negative.

    Returns the updated ``(weights, selected_features)``.
    """
    residual = signal - t.einsum('bn,dn -> bd', weights, dictionary)
    # choose the element with largest inner product with residual, as in matched pursuit.
    inner_products = t.einsum('dn,bd -> bn', dictionary, residual)
    idxs = t.argmax(inner_products, dim=1)
    # add the new feature to the active set.
    selected_features[batch_arange, idxs] = 1

    # the gradient for the weights is the inner product, restricted to the chosen features
    grad = selected_features * inner_products
    # optimal step size along `grad`: <c, residual> / <c, c>
    c = t.einsum('bn,dn -> bd', grad, dictionary)
    numerator = t.einsum('bd,bd -> b', c, residual)
    denominator = t.einsum('bd,bd -> b', c, c)
    # If the residual is already orthogonal to every selected atom (zero signal,
    # exact reconstruction), c == 0 and the step is 0/0; take no step rather
    # than filling the weights with NaN.
    step_size = t.where(denominator > 0, numerator / denominator, t.zeros_like(numerator))
    weights = weights + t.einsum('b,bn -> bn', step_size, grad)
    weights = t.clip(weights, min=0)  # clip the weights to be positive
    return weights, selected_features

def grad_pursuit(signal, dictionary, target_l0 : int = 20, device : str = 'cpu'):
    """
    Inputs: signal: b x d, dictionary: d x n, target_l0: int, device: str
    Outputs: weights: b x n

    Runs ``target_l0`` pursuit steps, so each row ends with at most
    ``target_l0`` active atoms. Weights use the dictionary's dtype.
    """
    assert len(signal.shape) == 2 # makes sure this a batch of signals
    with t.no_grad():
        batch_size, n_atoms = signal.shape[0], dictionary.shape[1]
        batch_arange = t.arange(batch_size, device=device)
        # match the dictionary's dtype so the einsums also work for non-float32 dictionaries
        weights = t.zeros((batch_size, n_atoms), dtype=dictionary.dtype, device=device)
        selected_features = t.zeros((batch_size, n_atoms), dtype=dictionary.dtype, device=device)
        for _ in range(target_l0):
            weights, selected_features = _grad_pursuit_update_step(
                signal, weights, dictionary, batch_arange, selected_features)
        return weights
--------------------------------------------------------------------------------
/dictionary_learning/interp.py:
--------------------------------------------------------------------------------
1 | import random
2 | from circuitsvis.activations import text_neuron_activations
3 | from einops import rearrange
4 | import torch as t
5 | from collections import namedtuple
6 | import umap
7 | import pandas as pd
8 | import plotly.express as px
9 |
10 | TRACER_KWARGS = {"scan" : False, "validate" : False}
11 |
def feature_effect(
    model,
    submodule,
    dictionary,
    feature,
    inputs,
    max_length=128,
    add_residual=True,  # whether to compensate for dictionary reconstruction error by adding residual
    k=10,
    largest=True,
):
    """
    Effect of ablating the feature on top k predictions for next token.

    Compares last-position next-token log-probs between a clean run and a run
    in which ``feature`` is zeroed at the last position: in the dictionary's
    feature space, or directly in the submodule output when ``dictionary`` is
    None. If a dictionary is given and ``add_residual`` is False, the clean run
    also passes the submodule output through the dictionary.

    Returns:
        ``(top_tokens, top_probs)``: the ``k`` token ids with the largest (or
        smallest, if ``largest=False``) batch-mean
        ``clean_logprob - ablated_logprob``, and those differences.
    """
    # clean run
    with t.no_grad(), model.trace(inputs, invoker_args=dict(max_length=max_length, truncation=True)):
        if dictionary is None:
            pass
        elif not add_residual:  # run hidden state through autoencoder
            if type(submodule.output.shape) == tuple:
                submodule.output[0][:] = dictionary(submodule.output[0])
            else:
                submodule.output = dictionary(submodule.output)
        clean_output = model.output.save()
    # HF outputs wrap logits in `.logits`; a raw tensor has no such attribute
    try:
        clean_logits = clean_output.value.logits[:, -1, :]
    except AttributeError:
        clean_logits = clean_output.value[:, -1, :]
    clean_logprobs = t.nn.functional.log_softmax(clean_logits, dim=-1)

    # ablated run
    with t.no_grad(), model.trace(inputs, invoker_args=dict(max_length=max_length, truncation=True)):
        if dictionary is None:
            if type(submodule.output.shape) == tuple:
                submodule.output[0][:, -1, feature] = 0
            else:
                submodule.output[:, -1, feature] = 0
        else:
            x = submodule.output
            if type(x.shape) == tuple:
                x = x[0]
            x_hat, f = dictionary(x, output_features=True)
            residual = x - x_hat

            f[:, -1, feature] = 0
            if add_residual:
                x_hat = dictionary.decode(f) + residual
            else:
                x_hat = dictionary.decode(f)

            if type(submodule.output.shape) == tuple:
                submodule.output[0][:] = x_hat
            else:
                submodule.output = x_hat
        ablated_output = model.output.save()
    try:
        ablated_logits = ablated_output.value.logits[:, -1, :]
    except AttributeError:
        ablated_logits = ablated_output.value[:, -1, :]
    ablated_logprobs = t.nn.functional.log_softmax(ablated_logits, dim=-1)

    diff = clean_logprobs - ablated_logprobs
    top_probs, top_tokens = t.topk(diff.mean(dim=0), k=k, largest=largest)
    return top_tokens, top_probs
76 |
77 |
def examine_dimension(model, submodule, buffer, dictionary=None, max_length=128, n_inputs=512,
                      dim_idx=None, k=30):
    """
    Profile one dimension on a batch of ``n_inputs`` texts from ``buffer``.

    The dimension is a dictionary feature when ``dictionary`` is given,
    otherwise a raw unit of the submodule output. If ``dim_idx`` is None a
    random dimension is chosen.

    Returns:
        ``featureProfile(top_contexts, top_tokens, top_affected)``:
        top_contexts -- circuitsvis rendering of the ``k`` highest-activating contexts
        top_tokens   -- ``k`` (decoded token, mean activation) pairs, highest first
        top_affected -- (decoded token, log-prob difference) pairs from ``feature_effect``
    """

    def _list_decode(x):
        # recursively decode (nested lists of) token ids
        if isinstance(x, int):
            return model.tokenizer.decode(x)
        else:
            return [_list_decode(y) for y in x]

    inputs = buffer.text_batch(batch_size=n_inputs)
    with t.no_grad(), model.trace(inputs, invoker_args=dict(max_length=max_length, truncation=True)):
        tokens = model.input[1]['input_ids'].save() # if you're getting errors, check here; might only work for pythia models
        activations = submodule.output
        if type(activations.shape) == tuple:
            activations = activations[0]
        if dictionary is not None:
            activations = dictionary.encode(activations)
        # Pick the random dimension only once `activations` exists; the old code
        # read `activations.shape` before assignment (UnboundLocalError).
        # Relies on trace proxies exposing `.shape`, as used elsewhere in this file.
        if dim_idx is None:
            dim_idx = random.randint(0, activations.shape[-1]-1)
        activations = activations[:,:, dim_idx].save()
    activations = activations.value

    # get top k tokens by mean activation
    tokens = tokens.value
    token_mean_acts = {}
    for ctx in tokens:
        for tok in ctx:
            if tok.item() in token_mean_acts:
                continue
            idxs = (tokens == tok).nonzero(as_tuple=True)
            token_mean_acts[tok.item()] = activations[idxs].mean().item()
    top_tokens = sorted(token_mean_acts.items(), key=lambda x: x[1], reverse=True)[:k]
    top_tokens = [(model.tokenizer.decode(tok), act) for tok, act in top_tokens]

    # top k (context, position) pairs by activation
    flattened_acts = rearrange(activations, 'b n -> (b n)')
    topk_indices = t.argsort(flattened_acts, dim=0, descending=True)[:k]
    batch_indices = topk_indices // activations.shape[1]
    token_indices = topk_indices % activations.shape[1]
    tokens = [
        tokens[batch_idx, :token_idx+1].tolist() for batch_idx, token_idx in zip(batch_indices, token_indices)
    ]
    activations = [
        activations[batch_idx, :token_id+1, None, None] for batch_idx, token_id in zip(batch_indices, token_indices)
    ]
    decoded_tokens = _list_decode(tokens)
    top_contexts = text_neuron_activations(decoded_tokens, activations)

    top_affected = feature_effect(
        model,
        submodule,
        dictionary,
        dim_idx,
        tokens,
        max_length=max_length,
        k=k
    )
    top_affected = [(model.tokenizer.decode(tok), prob.item()) for tok, prob in zip(*top_affected)]

    return namedtuple('featureProfile', ['top_contexts', 'top_tokens', 'top_affected'])(top_contexts, top_tokens, top_affected)
138 |
def feature_umap(
    dictionary,
    weight='decoder', # 'encoder' or 'decoder'
    # UMAP parameters
    n_neighbors=15,
    metric='cosine',
    min_dist=0.05,
    n_components=2, # dimension of the UMAP embedding
    feat_idxs=None, # if not none, indicate the feature with a red dot
):
    """
    Fit a UMAP embedding of the dictionary features and return a plotly plot of the result.

    Each point is one feature: a row of ``encoder.weight`` when
    ``weight == 'encoder'``, otherwise a column of ``decoder.weight``.
    ``feat_idxs`` may be None (no highlighting), a single int, or a sequence of
    feature indices to mark in red.

    Raises:
        ValueError: if ``n_components`` is not 2 or 3 (after fitting).
    """
    if weight == 'encoder':
        df = pd.DataFrame(dictionary.encoder.weight.cpu().detach().numpy())
    else:
        df = pd.DataFrame(dictionary.decoder.weight.T.cpu().detach().numpy())
    reducer = umap.UMAP(
        n_neighbors=n_neighbors,
        metric=metric,
        min_dist=min_dist,
        n_components=n_components,
    )
    embedding = reducer.fit_transform(df)
    # Previously None reached `i not in None` (TypeError) and an int left
    # `colors` unbound; handle both explicitly.
    if feat_idxs is None:
        colors = None
    else:
        if isinstance(feat_idxs, int):
            feat_idxs = [feat_idxs]
        colors = ['blue' if i not in feat_idxs else 'red' for i in range(embedding.shape[0])]
    if n_components == 2:
        return px.scatter(x=embedding[:, 0], y=embedding[:, 1], hover_name=df.index, color=colors)
    if n_components == 3:
        return px.scatter_3d(x=embedding[:, 0], y=embedding[:, 1], z=embedding[:, 2], hover_name=df.index, color=colors)
    raise ValueError("n_components must be 2 or 3")
--------------------------------------------------------------------------------
/dictionary_learning/pretrained_dictionary_downloader.sh:
--------------------------------------------------------------------------------
#!/bin/bash
#
# Download pretrained dictionaries for pythia-70m-deduped, mirroring the remote
# layout ${BASE_URL}/<submodule>/<sae_set_name>/<file> under ${LOCAL_DIR}.
#
# Usage: pretrained_dictionary_downloader.sh [-c|--checkpoints] [--layers LIST]
#   -c, --checkpoints   also download each dictionary's checkpoints/ folder (*.pt files)
#   --layers LIST       comma-separated layers, e.g. "0,3,embed"
#                       (default: 0,1,2,3,4,5,embed)

# Base URL for downloads - replace this with your actual URL, excluding the /aX part
BASE_URL="https://baulab.us/u/smarks/autoencoders/pythia-70m-deduped"

# Local directory where the remote folder structure is replicated
LOCAL_DIR="dictionaries/pythia-70m-deduped"

# Submodule name templates; "X" is replaced by each requested layer number.
# "embed" is not a per-layer template and is handled separately below.
declare -a default_a_values=("attn_out_layerX" "mlp_out_layerX" "resid_out_layerX")

# Files fetched for every submodule ("checkpoints" is appended when -c is given)
declare -a c_values=("ae.pt" "config.json")

# Name of the set of autoencoders
sae_set_name="10_32768"

# Checkpoints flag variable, default is 0 (don't download checkpoints)
download_checkpoints=0

# Layers to download unless overridden with --layers
declare -a custom_layers=("0" "1" "2" "3" "4" "5" "embed")

# Parse flags
while [[ "$#" -gt 0 ]]; do
    case $1 in
        -c|--checkpoints) download_checkpoints=1 ;;
        --layers)
            # Without a value, read would empty custom_layers and the URLs below
            # would keep the literal "X" placeholder, so reject it up front.
            if [[ -z "${2:-}" ]]; then
                echo "--layers requires a comma-separated list of layers"; exit 1
            fi
            IFS=',' read -ra custom_layers <<< "$2"; shift ;;
        *) echo "Unknown parameter passed: $1"; exit 1 ;;
    esac
    shift
done

# Ensure the root download directory exists
mkdir -p "${LOCAL_DIR}"

# Include checkpoints if flag is set
if [[ $download_checkpoints -eq 1 ]]; then
    c_values+=("checkpoints")
fi

# Expand the requested layers into concrete submodule names
declare -a a_values=()
if [[ ${#custom_layers[@]} -eq 0 ]]; then
    a_values=("${default_a_values[@]}")
else
    for layer in "${custom_layers[@]}"; do
        if [[ $layer == "embed" ]]; then
            a_values+=("embed")
        else
            for a_value in "${default_a_values[@]}"; do
                a_values+=("${a_value/X/$layer}")
            done
        fi
    done
fi

# Download logic
for a_value in "${a_values[@]}"; do
    for c in "${c_values[@]}"; do
        DOWNLOAD_URL="${BASE_URL}/${a_value}/${sae_set_name}/${c}"
        LOCAL_PATH="${LOCAL_DIR}/${a_value}/${sae_set_name}/${c}"
        if [ "${c}" == "checkpoints" ]; then
            # Recursively fetch the checkpoints folder (only *.pt files).
            # --cut-dirs=7 strips u/smarks/autoencoders/pythia-70m-deduped/<a>/<set>/checkpoints
            # so files land directly in LOCAL_PATH; update it if BASE_URL's depth changes.
            mkdir -p "${LOCAL_PATH}"
            wget -r -np -nH --cut-dirs=7 -P "${LOCAL_PATH}" --accept "*.pt" "${DOWNLOAD_URL}/"

        else
            # Handle all other files
            mkdir -p "$(dirname "${LOCAL_PATH}")"
            wget -P "$(dirname "${LOCAL_PATH}")" "${DOWNLOAD_URL}"
        fi
    done
done

echo "Download completed."
77 |
--------------------------------------------------------------------------------
/dictionary_learning/requirements.txt:
--------------------------------------------------------------------------------
1 | circuitsvis>=1.43.2
2 | datasets>=2.18.0
3 | einops>=0.7.0
4 | matplotlib>=3.8.3
5 | nnsight>=0.2.11
6 | pandas>=2.2.1
7 | plotly>=5.18.0
8 | torch>=2.1.2
9 | tqdm>=4.66.1
10 | umap-learn>=0.5.6
11 | zstandard>=0.22.0
12 | wandb
13 |
--------------------------------------------------------------------------------
/dictionary_learning/trainers/gdm.py:
--------------------------------------------------------------------------------
1 | """
2 | Implements the training scheme for a gated SAE described in https://arxiv.org/abs/2404.16014
3 | """
4 |
5 | import torch as t
6 | from ..trainers.trainer import SAETrainer
7 | from ..config import DEBUG
8 | from ..dictionary import GatedAutoEncoder
9 | from collections import namedtuple
10 |
class ConstrainedAdam(t.optim.Adam):
    """
    A variant of Adam where some of the parameters are constrained to have unit norm.

    For each constrained parameter, the gradient component parallel to each
    column (norm taken over dim 0) is projected out before the Adam step, and
    every column is renormalized to unit norm afterwards. Adam runs with
    betas=(0, 0.999).
    """
    def __init__(self, params, constrained_params, lr):
        super().__init__(params, lr=lr, betas=(0, 0.999))
        self.constrained_params = list(constrained_params)

    def step(self, closure=None):
        with t.no_grad():
            for p in self.constrained_params:
                # A parameter that received no gradient has nothing to project
                # (and `None -= ...` would raise TypeError).
                if p.grad is None:
                    continue
                normed_p = p / p.norm(dim=0, keepdim=True)
                # project away the parallel component of the gradient
                p.grad -= (p.grad * normed_p).sum(dim=0, keepdim=True) * normed_p
        super().step(closure=closure)
        with t.no_grad():
            for p in self.constrained_params:
                # renormalize the constrained parameters
                p /= p.norm(dim=0, keepdim=True)
30 |
class GatedSAETrainer(SAETrainer):
    """
    Gated SAE training scheme.

    Loss = reconstruction MSE of the gated output
         + l1_penalty * L1 norm of the gate activations ``f_gate``
         + auxiliary MSE of reconstructing x from ``f_gate`` through a detached
           copy of the decoder weights and bias.
    The decoder is optimized with ``ConstrainedAdam`` (unit-norm columns), and
    the learning rate warms up linearly over ``warmup_steps`` steps.
    """
    def __init__(self,
                 dict_class=GatedAutoEncoder,
                 activation_dim=512,
                 dict_size=64*512,
                 lr=5e-5,
                 l1_penalty=1e-1,
                 warmup_steps=1000, # lr warmup period at start of training and after each resample
                 resample_steps=None, # accepted for interface compatibility; this trainer does not resample
                 seed=None,
                 device=None,
                 layer=None,
                 lm_name=None,
                 wandb_name='GatedSAETrainer',
                 ):
        super().__init__(seed)

        assert layer is not None and lm_name is not None
        self.layer = layer
        self.lm_name = lm_name

        if seed is not None:
            t.manual_seed(seed)
            t.cuda.manual_seed_all(seed)

        # initialize dictionary
        self.ae = dict_class(activation_dim, dict_size)

        self.lr = lr
        self.l1_penalty=l1_penalty
        self.warmup_steps = warmup_steps
        self.wandb_name = wandb_name

        if device is None:
            self.device = 'cuda' if t.cuda.is_available() else 'cpu'
        else:
            self.device = device
        self.ae.to(self.device)

        self.optimizer = ConstrainedAdam(
            self.ae.parameters(),
            self.ae.decoder.parameters(),
            lr=lr
        )
        def warmup_fn(step):
            # Linear warmup to the full lr. LambdaLR calls this with step 0 at
            # construction, so warmup_steps == 0 used to raise ZeroDivisionError;
            # with no warmup, use the full lr immediately.
            if warmup_steps <= 0:
                return 1.0
            return min(1, step / warmup_steps)
        self.scheduler = t.optim.lr_scheduler.LambdaLR(self.optimizer, warmup_fn)

    def loss(self, x, logging=False, **kwargs):
        """Return the scalar loss, or a LossLog namedtuple when ``logging`` is True."""
        f, f_gate = self.ae.encode(x, return_gate=True)
        x_hat = self.ae.decode(f)
        # reconstruction from the gate path with decoder parameters detached
        x_hat_gate = f_gate @ self.ae.decoder.weight.detach().T + self.ae.decoder_bias.detach()

        L_recon = (x - x_hat).pow(2).sum(dim=-1).mean()
        L_sparse = t.linalg.norm(f_gate, ord=1, dim=-1).mean()
        L_aux = (x - x_hat_gate).pow(2).sum(dim=-1).mean()

        loss = L_recon + self.l1_penalty * L_sparse + L_aux

        if not logging:
            return loss
        else:
            return namedtuple('LossLog', ['x', 'x_hat', 'f', 'losses'])(
                x, x_hat, f,
                {
                    'mse_loss' : L_recon.item(),
                    'sparsity_loss' : L_sparse.item(),
                    'aux_loss' : L_aux.item(),
                    'loss' : loss.item()
                }
            )

    def update(self, step, x):
        """Take one optimizer + scheduler step on the batch ``x``."""
        x = x.to(self.device)
        self.optimizer.zero_grad()
        loss = self.loss(x)
        loss.backward()
        self.optimizer.step()
        self.scheduler.step()

    @property
    def config(self):
        return {
            'trainer_class' : 'GatedSAETrainer',
            'activation_dim' : self.ae.activation_dim,
            'dict_size' : self.ae.dict_size,
            'lr' : self.lr,
            'l1_penalty' : self.l1_penalty,
            'warmup_steps' : self.warmup_steps,
            'device' : self.device,
            'layer' : self.layer,
            'lm_name' : self.lm_name,
            'wandb_name': self.wandb_name,
        }
128 |
--------------------------------------------------------------------------------
/dictionary_learning/trainers/jump_relu.py:
--------------------------------------------------------------------------------
1 | import torch as t
2 | import torch.autograd as autograd
3 | from ..trainers.trainer import SAETrainer
4 | from ..dictionary import JumpReluAutoEncoder, StepFunction
5 |
6 | from ..config import DEBUG
7 | from collections import namedtuple
8 |
class JumpReluTrainer(SAETrainer):
    """
    Trainer for JumpReLU SAEs.

    Loss = squared-error reconstruction + l0_penalty * L_spars, where L_spars is
    the per-example sum of StepFunction.apply(f, threshold, bandwidth) averaged
    over the batch (presumably an L0 count with a pseudo-derivative for the
    threshold -- confirm against StepFunction in dictionary.py).
    Optimized with Adam(betas=(0, 0.999)) and a linear lr warmup over
    `warmup_steps` steps.
    """
    def __init__(self,
                 dict_class=JumpReluAutoEncoder,
                 activation_dim=512,
                 dict_size=64*512,
                 lr=5e-5,
                 l0_penalty=1e-1,
                 warmup_steps=1000, # lr warmup period at start of training
                 seed=None,
                 device=None,
                 layer=None,
                 lm_name=None,
                 wandb_name='JumpReluTrainer',
                 submodule_name=None,
                 set_linear_to_constant=False,
                 pre_encoder_bias=True,
                 ):
        super().__init__(seed)
        if seed is not None:
            t.manual_seed(seed)
            t.cuda.manual_seed_all(seed)

        assert layer is not None and lm_name is not None
        self.layer = layer
        self.lm_name = lm_name
        self.submodule_name = submodule_name
        self.wandb_name = wandb_name

        if device is None:
            self.device = 'cuda' if t.cuda.is_available() else 'cpu'
        else:
            self.device = device

        self.ae = dict_class(activation_dim, dict_size, pre_encoder_bias=pre_encoder_bias)
        self.ae.to(self.device)
        self.lr = lr
        self.warmup_steps = warmup_steps
        self.l0_penalty = l0_penalty
        self.set_linear_to_constant = set_linear_to_constant

        # `lr` must be passed explicitly: without it Adam silently used its own
        # default (1e-3) while self.lr / config reported a different value.
        self.optimizer = t.optim.Adam(self.ae.parameters(), lr=lr, betas=(0, 0.999), eps=1e-8)

        def warmup_fn(step):
            # warmup_steps <= 0 means "no warmup" instead of a ZeroDivisionError
            if warmup_steps <= 0:
                return 1.
            return min(1, step / warmup_steps)
        # LambdaLR applies warmup_fn(0) == 0 at construction, so update() must
        # step the scheduler or the learning rate would stay at 0 forever.
        self.scheduler = t.optim.lr_scheduler.LambdaLR(self.optimizer, warmup_fn)

    def loss(self, x, logging=False, **kwargs):
        """
        Compute the training loss on batch `x`.

        Returns the scalar loss, or with `logging=True` a LossLog namedtuple
        `(x, x_hat, f, losses)` with the individual terms as floats.
        Extra keyword arguments (e.g. `step`) are ignored.
        """
        x_hat, f = self.ae(x, output_features=True, set_linear_to_constant=self.set_linear_to_constant)
        L_recon = (x - x_hat).pow(2).sum(dim=-1).mean()
        # NOTE(review): the step function is applied to the post-activation
        # features `f` rather than pre-activations; confirm this is intended.
        L_spars = StepFunction.apply(f, self.ae.threshold, self.ae.bandwidth).sum(dim=-1).mean()

        loss = L_recon + self.l0_penalty * L_spars

        if not logging:
            return loss
        else:
            return namedtuple('LossLog', ['x', 'x_hat', 'f', 'losses'])(
                x, x_hat, f,
                losses={
                    'mse_loss' : L_recon.item(),
                    'sparsity_loss' : L_spars.item(),
                    'loss' : loss.item()
                }
            )

    def update(self, step, x):
        """Run one optimization step on batch `x` and advance the lr warmup."""
        x = x.to(self.device)
        self.optimizer.zero_grad()
        loss = self.loss(x)
        loss.backward()
        self.optimizer.step()
        self.scheduler.step()

    @property
    def config(self):
        """Serializable description of this trainer and its dictionary."""
        return {
            'dict_class': 'JumpReluAutoEncoder',
            'trainer_class': 'JumpReluTrainer',
            'activation_dim': self.ae.activation_dim,
            'dict_size': self.ae.dict_size,
            'lr': self.lr,
            'warmup_steps': self.warmup_steps,
            'l0_penalty': self.l0_penalty,
            'device': self.device,
            'layer': self.layer,
            'lm_name': self.lm_name,
            'submodule_name': self.submodule_name,
            'wandb_name': self.wandb_name,
        }
--------------------------------------------------------------------------------
/dictionary_learning/trainers/standard.py:
--------------------------------------------------------------------------------
1 | """
2 | Implements the standard SAE training scheme.
3 | """
4 | import torch as t
5 | from ..trainers.trainer import SAETrainer
6 | from ..config import DEBUG
7 | from ..dictionary import AutoEncoder
8 | from collections import namedtuple
9 |
class ConstrainedAdam(t.optim.Adam):
    """
    A variant of Adam where some of the parameters are constrained to have unit norm.

    Constrained parameters are treated column-wise (norms over dim 0). Before the
    Adam update, the component of each column's gradient parallel to that column
    is projected out. After the update, every column is renormalized to unit L2
    norm. Constrained parameters without a gradient are left to Adam, which
    skips them.
    """
    def __init__(self, params, constrained_params, lr):
        super().__init__(params, lr=lr)
        self.constrained_params = list(constrained_params)

    def step(self, closure=None):
        """Perform one constrained Adam step; returns whatever `closure` returned (or None)."""
        with t.no_grad():
            for p in self.constrained_params:
                if p.grad is None:
                    # unused this step: nothing to project (previously a TypeError)
                    continue
                normed_p = p / p.norm(dim=0, keepdim=True)
                # project away the parallel component of the gradient
                p.grad -= (p.grad * normed_p).sum(dim=0, keepdim=True) * normed_p
        loss = super().step(closure=closure)
        with t.no_grad():
            for p in self.constrained_params:
                # renormalize the constrained parameters
                p /= p.norm(dim=0, keepdim=True)
        # Optimizer.step is documented to return the closure's loss.
        return loss
29 |
class StandardTrainer(SAETrainer):
    """
    Standard SAE training scheme.

    Loss = L2 reconstruction error (batch mean of ||x - x_hat||_2)
    + l1_penalty * batch mean of ||f||_1. Parameters are optimized with
    ConstrainedAdam, which keeps the decoder columns at unit norm. The lr warms up
    linearly over `warmup_steps`. When `resample_steps` is set, the warmup restarts
    every `resample_steps` steps, and at those steps neurons inactive for more than
    `resample_steps / 2` steps are resampled.
    """
    def __init__(self,
                 dict_class=AutoEncoder,
                 activation_dim=512,
                 dict_size=64*512,
                 lr=1e-3,
                 l1_penalty=1e-1,
                 warmup_steps=1000, # lr warmup period at start of training and after each resample
                 resample_steps=None, # how often to resample neurons
                 seed=None,
                 device=None,
                 layer=None,
                 lm_name=None,
                 wandb_name='StandardTrainer',
                 ):
        super().__init__(seed)

        assert layer is not None and lm_name is not None
        self.layer = layer
        self.lm_name = lm_name

        if seed is not None:
            t.manual_seed(seed)
            t.cuda.manual_seed_all(seed)

        # initialize dictionary
        self.ae = dict_class(activation_dim, dict_size)

        self.lr = lr
        self.l1_penalty = l1_penalty
        self.warmup_steps = warmup_steps
        self.wandb_name = wandb_name

        if device is None:
            self.device = 'cuda' if t.cuda.is_available() else 'cpu'
        else:
            self.device = device
        self.ae.to(self.device)

        self.resample_steps = resample_steps

        if self.resample_steps is not None:
            # how many steps since each neuron was last activated?
            self.steps_since_active = t.zeros(self.ae.dict_size, dtype=int).to(self.device)
        else:
            self.steps_since_active = None

        self.optimizer = ConstrainedAdam(self.ae.parameters(), self.ae.decoder.parameters(), lr=lr)

        def warmup_fn(step):
            # warmup_steps <= 0 means "no warmup" instead of a ZeroDivisionError
            if warmup_steps <= 0:
                return 1.
            if resample_steps is not None:
                # restart the warmup after every resample
                step = step % resample_steps
            return min(step / warmup_steps, 1.)
        self.scheduler = t.optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=warmup_fn)

    def resample_neurons(self, deads, activations):
        """
        Re-initialize dead neurons from poorly reconstructed inputs.

        Args:
            deads: boolean mask over dictionary features. It is modified in place:
                only the first `n_resample` dead entries stay True.
            activations: batch of inputs to draw replacement directions from.
        """
        with t.no_grad():
            if deads.sum() == 0:
                return
            print(f"resampling {deads.sum().item()} neurons")

            # compute loss for each activation
            losses = (activations - self.ae(activations)).norm(dim=-1)

            # sample inputs with probability proportional to their reconstruction error
            n_resample = min(int(deads.sum()), losses.shape[0])
            indices = t.multinomial(losses, num_samples=n_resample, replacement=False)
            sampled_vecs = activations[indices]

            # get norm of the living neurons
            alive_norm = self.ae.encoder.weight[~deads].norm(dim=-1).mean()

            # resample first n_resample dead neurons
            deads[deads.nonzero()[n_resample:]] = False
            self.ae.encoder.weight[deads] = sampled_vecs * alive_norm * 0.2
            self.ae.decoder.weight[:, deads] = (sampled_vecs / sampled_vecs.norm(dim=-1, keepdim=True)).T
            self.ae.encoder.bias[deads] = 0.

            # reset Adam parameters for dead neurons
            # NOTE(review): the hard-coded indices assume ae.parameters() yields
            # [?, encoder.weight, encoder.bias, decoder.weight] in that order --
            # confirm against AutoEncoder in dictionary.py.
            state_dict = self.optimizer.state_dict()['state']
            ## encoder weight
            state_dict[1]['exp_avg'][deads] = 0.
            state_dict[1]['exp_avg_sq'][deads] = 0.
            ## encoder bias
            state_dict[2]['exp_avg'][deads] = 0.
            state_dict[2]['exp_avg_sq'][deads] = 0.
            ## decoder weight
            state_dict[3]['exp_avg'][:, deads] = 0.
            state_dict[3]['exp_avg_sq'][:, deads] = 0.

    def loss(self, x, logging=False, **kwargs):
        """
        Compute the training loss on batch `x`.

        Returns the scalar loss, or with `logging=True` a LossLog namedtuple
        `(x, x_hat, f, losses)` with the individual terms as floats.
        Extra keyword arguments (e.g. `step`) are ignored.
        """
        x_hat, f = self.ae(x, output_features=True)
        l2_loss = t.linalg.norm(x - x_hat, dim=-1).mean()
        l1_loss = f.norm(p=1, dim=-1).mean()

        # Dead-neuron bookkeeping only on training passes. trainSAE also calls
        # loss(logging=True) on the same batch just before update(), and that
        # pass must not count as an extra step.
        if self.steps_since_active is not None and not logging:
            deads = (f == 0).all(dim=0)
            self.steps_since_active[deads] += 1
            self.steps_since_active[~deads] = 0

        loss = l2_loss + self.l1_penalty * l1_loss

        if not logging:
            return loss
        else:
            return namedtuple('LossLog', ['x', 'x_hat', 'f', 'losses'])(
                x, x_hat, f,
                {
                    'l2_loss' : l2_loss.item(),
                    'mse_loss' : (x - x_hat).pow(2).sum(dim=-1).mean().item(),
                    'sparsity_loss' : l1_loss.item(),
                    'loss' : loss.item()
                }
            )

    def update(self, step, activations):
        """Run one optimization step, then resample dead neurons every `resample_steps` steps."""
        activations = activations.to(self.device)

        self.optimizer.zero_grad()
        loss = self.loss(activations)
        loss.backward()
        self.optimizer.step()
        self.scheduler.step()

        if self.resample_steps is not None and step % self.resample_steps == 0:
            self.resample_neurons(self.steps_since_active > self.resample_steps / 2, activations)

    @property
    def config(self):
        """Serializable description of this trainer (includes dictionary shape, like the other trainers)."""
        return {
            'trainer_class' : 'StandardTrainer',
            'activation_dim' : self.ae.activation_dim,
            'dict_size' : self.ae.dict_size,
            'lr' : self.lr,
            'l1_penalty' : self.l1_penalty,
            'warmup_steps' : self.warmup_steps,
            'resample_steps' : self.resample_steps,
            'device' : self.device,
            'layer' : self.layer,
            'lm_name' : self.lm_name,
            'wandb_name': self.wandb_name,
        }
177 |
--------------------------------------------------------------------------------
/dictionary_learning/trainers/standard_new.py:
--------------------------------------------------------------------------------
1 | """
2 | Implements the SAE training scheme from https://transformer-circuits.pub/2024/april-update/index.html#training-saes
3 | """
4 | import torch as t
5 | from ..trainers.trainer import SAETrainer
6 | from ..config import DEBUG
7 | from ..dictionary import AutoEncoderNew
8 | from collections import namedtuple
9 |
class StandardTrainerNew(SAETrainer):
    """
    Standard SAE training scheme (variant following Anthropic's April 2024 update).

    Loss = squared-error reconstruction + lambda * batch mean of ||f||_1, where
    lambda ramps linearly from 0 to `l1_penalty` over `lambda_warm_steps` steps.
    Optimized with Adam and gradient-norm clipping at 1.0. The lr is constant
    until `decay_start` and then decays linearly to 0 at `steps`.

    NOTE(review): the referenced update weights each feature's L1 term by its
    decoder-column norm; that is not done here -- confirm whether AutoEncoderNew
    keeps decoder norms fixed or whether this is intentional.
    """
    def __init__(self,
                 dict_class=AutoEncoderNew,
                 activation_dim=512,
                 dict_size=64*512,
                 lr=5e-5,
                 l1_penalty=1e-1,
                 lambda_warm_steps=1500, # steps over which to warm up the l1 penalty
                 decay_start=24000, # when does the lr decay start
                 steps=30000, # when does training end
                 seed=None,
                 device=None,
                 wandb_name='StandardTrainerNew_Anthropic',
                 ):
        super().__init__(seed)

        if seed is not None:
            t.manual_seed(seed)
            t.cuda.manual_seed_all(seed)

        # initialize dictionary
        self.ae = dict_class(activation_dim, dict_size)

        self.lr = lr
        self.l1_penalty = l1_penalty
        self.lambda_warm_steps = lambda_warm_steps
        self.decay_start = decay_start
        self.steps = steps
        self.wandb_name = wandb_name

        if device is None:
            self.device = 'cuda' if t.cuda.is_available() else 'cpu'
        else:
            self.device = device
        self.ae.to(self.device)

        self.optimizer = t.optim.Adam(self.ae.parameters(), lr=lr)

        def lr_fn(step):
            if step < decay_start:
                return 1.
            # Linear decay to 0 at `steps`. Clamped so running past `steps` cannot
            # produce a negative lr (gradient ascent); the denominator guard
            # avoids dividing by zero when steps == decay_start.
            return max(0., (steps - step) / max(steps - decay_start, 1))
        self.scheduler = t.optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=lr_fn)

    def loss(self, x, step=None, logging=False):
        """
        Compute the training loss on batch `x` at training step `step`.

        When `step` is None (e.g. a standalone evaluation call), the fully
        warmed-up L1 penalty is used; previously this raised a TypeError.
        Returns the scalar loss, or with `logging=True` a LossLog namedtuple
        `(x, x_hat, f, losses)`.
        """
        # NOTE: inputs are deliberately not normalized here (the recipe would
        # rescale x to mean norm sqrt(activation_dim)).
        x_hat, f = self.ae(x, output_features=True)

        if step is None or self.lambda_warm_steps <= 0:
            l1_penalty = self.l1_penalty
        else:
            l1_penalty = min(1., step / self.lambda_warm_steps) * self.l1_penalty

        L_recon = (x - x_hat).pow(2).sum(dim=-1).mean()
        L_sparse = f.norm(p=1, dim=-1).mean()

        loss = L_recon + l1_penalty * L_sparse

        if not logging:
            return loss
        else:
            return namedtuple('LossLog', ['x', 'x_hat', 'f', 'losses'])(
                x, x_hat, f,
                {
                    'mse_loss' : L_recon.item(),
                    'sparsity_loss' : L_sparse.item(),
                    'l1_penalty' : l1_penalty,
                    'loss' : loss.item()
                }
            )

    def update(self, step, x):
        """Run one clipped optimization step on batch `x`; returns the loss as a float."""
        x = x.to(self.device)

        # NOTE: not using input normalization (see loss)

        self.optimizer.zero_grad()
        loss = self.loss(x, step=step)
        loss.backward()

        # clip grad norm
        t.nn.utils.clip_grad_norm_(self.ae.parameters(), 1.0)

        self.optimizer.step()
        self.scheduler.step()

        return loss.item()

    @property
    def config(self):
        """Serializable description of this trainer and its dictionary."""
        return {
            'trainer_class' : 'StandardTrainerNew',
            'dict_class' : 'AutoEncoderNew',
            'lr' : self.lr,
            'l1_penalty' : self.l1_penalty,
            'lambda_warm_steps' : self.lambda_warm_steps,
            'decay_start' : self.decay_start,
            'steps' : self.steps,
            'seed' : self.seed,
            'activation_dim' : self.ae.activation_dim,
            'dict_size' : self.ae.dict_size,
            'device' : self.device,
            'wandb_name' : self.wandb_name,
        }
--------------------------------------------------------------------------------
/dictionary_learning/trainers/top_k.py:
--------------------------------------------------------------------------------
1 | """
2 | Implements the SAE training scheme from https://arxiv.org/abs/2406.04093.
3 | Significant portions of this code have been copied from https://github.com/EleutherAI/sae/blob/main/sae
4 | """
5 |
6 | import einops
7 | import torch as t
8 | import torch.nn as nn
9 | from collections import namedtuple
10 |
11 | from ..config import DEBUG
12 | from ..dictionary import Dictionary
13 | from ..kernels import TritonDecoder
14 | from ..trainers.trainer import SAETrainer
15 |
16 |
@t.no_grad()
def geometric_median(points: t.Tensor, max_iter: int = 100, tol: float = 1e-5):
    """
    Compute the geometric median of `points` (rows are points). Used for initializing decoder bias.

    Weiszfeld iteration starting from the mean: each step takes the average of
    the points weighted by inverse distance to the current guess. Stops after
    `max_iter` iterations or once the guess moves less than `tol`.
    """
    guess = points.mean(dim=0)
    # Distances are floored at machine epsilon. If the guess lands exactly on a
    # data point, 1 / 0 would make the weights inf and the result NaN.
    eps = t.finfo(points.dtype).eps

    for _ in range(max_iter):
        prev = guess

        # inverse-distance weights, normalized to sum to 1
        weights = 1 / t.norm(points - guess, dim=1).clamp(min=eps)
        weights /= weights.sum()

        # Compute the new geometric median estimate
        guess = (weights.unsqueeze(1) * points).sum(dim=0)

        # Early stopping condition
        if t.norm(guess - prev) < tol:
            break

    return guess
44 |
45 |
class AutoEncoderTopK(Dictionary, nn.Module):
    """
    The top-k autoencoder architecture and initialization used in https://arxiv.org/abs/2406.04093

    encoder: Linear(activation_dim -> dict_size) applied to (x - b_dec), then ReLU.
    decoder: parameter of shape (dict_size, activation_dim), one unit-norm row per
    feature. Only the k largest activations per token are decoded.
    """
    def __init__(self, activation_dim, dict_size, k):
        super().__init__()
        self.activation_dim = activation_dim
        self.dict_size = dict_size
        self.k = k

        self.encoder = nn.Linear(activation_dim, dict_size)
        self.encoder.bias.data.zero_()

        # decoder starts as a copy of the encoder weight, rescaled to unit-norm rows
        self.decoder = nn.Parameter(self.encoder.weight.data.clone())
        self.set_decoder_norm_to_unit_norm()

        self.b_dec = nn.Parameter(t.zeros(activation_dim))

    def encode(self, x):
        """Return all (pre-top-k) ReLU feature activations for `x`."""
        return nn.functional.relu(self.encoder(x - self.b_dec))

    def decode(self, top_acts, top_indices):
        """Reconstruct from sparse (values, indices) pairs via the Triton kernel, plus b_dec."""
        d = TritonDecoder.apply(top_indices, top_acts, self.decoder.mT)
        return d + self.b_dec

    def forward(self, x, output_features=False):
        """
        Encode, keep the top-k features per token, and decode.

        Leading dims of `x` are flattened for encode/decode and restored on output.
        Returns x_hat; (x_hat, f) if `output_features` is truthy; or
        (x_hat, f, top_acts, top_indices) if `output_features == "all"`.
        """
        # reshape (not view) so non-contiguous inputs are accepted too
        f = self.encode(x.reshape(-1, x.shape[-1]))
        top_acts, top_indices = f.topk(self.k, sorted=False)
        x_hat = self.decode(top_acts, top_indices).view(x.shape)
        f = f.view(*x.shape[:-1], f.shape[-1])
        if not output_features:
            return x_hat
        elif output_features == "all":
            return x_hat, f, top_acts, top_indices
        else:
            return x_hat, f

    @t.no_grad()
    def set_decoder_norm_to_unit_norm(self):
        """Rescale each decoder row (one per feature) to unit L2 norm."""
        eps = t.finfo(self.decoder.dtype).eps
        norm = t.norm(self.decoder.data, dim=1, keepdim=True)
        self.decoder.data /= norm + eps

    @t.no_grad()
    def remove_gradient_parallel_to_decoder_directions(self):
        """
        Project out the component of each decoder row's gradient along that row.

        Assumes the rows are unit norm (see set_decoder_norm_to_unit_norm).
        """
        assert self.decoder.grad is not None # keep pyright happy

        parallel_component = einops.einsum(
            self.decoder.grad,
            self.decoder.data,
            "d_sae d_in, d_sae d_in -> d_sae",
        )
        self.decoder.grad -= einops.einsum(
            parallel_component,
            self.decoder.data,
            "d_sae, d_sae d_in -> d_sae d_in",
        )

    @staticmethod
    def from_pretrained(path, k=100, device=None):
        """
        Load a pretrained autoencoder from a state-dict file.

        activation_dim and dict_size are inferred from 'encoder.weight'; `k` is
        not stored in the file and must be supplied.
        """
        state_dict = t.load(path, map_location=device)
        dict_size, activation_dim = state_dict['encoder.weight'].shape
        autoencoder = AutoEncoderTopK(activation_dim, dict_size, k)
        autoencoder.load_state_dict(state_dict)
        if device is not None:
            autoencoder.to(device)
        return autoencoder
116 |
117 |
class TrainerTopK(SAETrainer):
    """
    Top-K SAE training scheme.

    Loss = squared-error reconstruction + auxk_alpha * AuxK loss. The AuxK term is
    only added when some latents are dead; it trains the top dead latents to
    predict the reconstruction error. lr = 2e-4 / sqrt(dict_size / 2**14), held
    constant until `decay_start`, then decayed linearly to 0 at `steps`.
    """
    def __init__(self,
                 dict_class=AutoEncoderTopK,
                 activation_dim=512,
                 dict_size=64*512,
                 k=100,
                 auxk_alpha=1/32, # see Appendix A.2
                 decay_start=24000, # when does the lr decay start
                 steps=30000, # when does training end
                 seed=None,
                 device=None,
                 layer=None,
                 lm_name=None,
                 wandb_name='AutoEncoderTopK',
                 submodule_name=None,
                 ):
        super().__init__(seed)

        assert layer is not None and lm_name is not None
        self.layer = layer
        self.lm_name = lm_name
        self.submodule_name = submodule_name

        self.wandb_name = wandb_name
        self.steps = steps
        self.k = k
        if seed is not None:
            t.manual_seed(seed)
            t.cuda.manual_seed_all(seed)

        # Initialise autoencoder
        self.ae = dict_class(activation_dim, dict_size, k)
        if device is None:
            self.device = 'cuda' if t.cuda.is_available() else 'cpu'
        else:
            self.device = device
        self.ae.to(self.device)

        # Auto-select LR using 1 / sqrt(d) scaling law from Figure 3 of the paper
        scale = dict_size / (2 ** 14)
        self.lr = 2e-4 / scale ** 0.5
        self.auxk_alpha = auxk_alpha
        self.dead_feature_threshold = 10_000_000

        # Optimizer and scheduler
        self.optimizer = t.optim.Adam(self.ae.parameters(), lr=self.lr, betas=(0.9, 0.999))

        def lr_fn(step):
            if step < decay_start:
                return 1.
            # Linear decay to 0 at `steps`. Clamped so running past `steps` cannot
            # produce a negative lr; the denominator guard covers steps == decay_start.
            return max(0., (steps - step) / max(steps - decay_start, 1))
        self.scheduler = t.optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=lr_fn)

        # Training tokens seen since each latent last fired. It must live on
        # self.device: it is indexed with top_indices computed there. Using the
        # raw `device` argument put it on the default (CPU) device when device=None.
        self.num_tokens_since_fired = t.zeros(dict_size, dtype=t.long, device=self.device)

        # Log the effective L0, i.e. number of features actually used, which should be a constant value (K)
        # Note: The standard L0 is essentially a measure of dead features for Top-K SAEs
        self.logging_parameters = ["effective_l0", "dead_features"]
        self.effective_l0 = -1
        self.dead_features = -1

    def loss(self, x, step=None, logging=False):
        """
        Compute the training loss on batch `x` (rows are tokens).

        Returns the scalar loss, or with `logging=True` a LossLog namedtuple
        `(x, x_hat, f, losses)`. Logging passes do not advance the dead-latent
        counters: trainSAE runs one on the same batch just before update().
        """
        x = x.to(self.device)

        # Run the SAE
        f = self.ae.encode(x)
        top_acts, top_indices = f.topk(self.k, sorted=False)
        x_hat = self.ae.decode(top_acts, top_indices)

        # Measure goodness of reconstruction
        e = x_hat - x
        total_variance = (x - x.mean(0)).pow(2).sum(0)

        # Update the effective L0 (again, should just be K)
        self.effective_l0 = top_acts.size(1)

        # Update "number of tokens since fired" for each feature (training passes only)
        if not logging:
            num_tokens_in_step = x.size(0)
            did_fire = t.zeros_like(self.num_tokens_since_fired, dtype=t.bool)
            did_fire[top_indices.flatten()] = True
            self.num_tokens_since_fired += num_tokens_in_step
            self.num_tokens_since_fired[did_fire] = 0

        # Dead features are always counted for logging. Previously auxk_alpha == 0
        # crashed on `None.to(...)`.
        dead_mask = (self.num_tokens_since_fired > self.dead_feature_threshold).to(f.device)
        num_dead = int(dead_mask.sum())
        self.dead_features = num_dead

        # If dead features: Second decoder pass for AuxK loss
        if self.auxk_alpha > 0 and num_dead > 0:

            # Heuristic from Appendix B.1 in the paper
            k_aux = x.shape[-1] // 2

            # Reduce the scale of the loss if there are a small number of dead latents
            scale = min(num_dead / k_aux, 1.0)
            k_aux = min(k_aux, num_dead)

            # Don't include living latents in this loss
            auxk_latents = t.where(dead_mask[None], f, -t.inf)

            # Top-k dead latents
            auxk_acts, auxk_indices = auxk_latents.topk(k_aux, sorted=False)

            # Encourage the top ~50% of dead latents to predict the residual of the
            # top k living latents.
            # NOTE(review): e = x_hat - x here, while the paper defines the
            # residual as x - x_hat; confirm the intended sign.
            e_hat = self.ae.decode(auxk_acts, auxk_indices)
            auxk_loss = (e_hat - e).pow(2)
            auxk_loss = scale * t.mean(auxk_loss / total_variance)
        else:
            auxk_loss = x_hat.new_tensor(0.0)

        l2_loss = e.pow(2).sum(dim=-1).mean()
        loss = l2_loss + self.auxk_alpha * auxk_loss

        if not logging:
            return loss
        else:
            return namedtuple('LossLog', ['x', 'x_hat', 'f', 'losses'])(
                x, x_hat, f,
                {
                    'l2_loss': l2_loss.item(),
                    'auxk_loss': auxk_loss.item(),
                    'loss' : loss.item()
                }
            )

    def update(self, step, x):
        """
        Run one training step on batch `x`; returns the loss as a float.

        At step 0 the decoder bias is initialized to the geometric median of `x`.
        """
        x = x.to(self.device)

        # Initialise the decoder bias
        if step == 0:
            median = geometric_median(x)
            self.ae.b_dec.data = median

        # Make sure the decoder is still unit-norm
        self.ae.set_decoder_norm_to_unit_norm()

        # compute the loss
        loss = self.loss(x, step=step)
        loss.backward()

        # clip grad norm and remove grads parallel to decoder directions
        t.nn.utils.clip_grad_norm_(self.ae.parameters(), 1.0)
        self.ae.remove_gradient_parallel_to_decoder_directions()

        # do a training step
        self.optimizer.step()
        self.optimizer.zero_grad()
        self.scheduler.step()
        return loss.item()

    @property
    def config(self):
        """Serializable description of this trainer and its dictionary."""
        return {
            'trainer_class' : 'TrainerTopK',
            'dict_class' : 'AutoEncoderTopK',
            'lr' : self.lr,
            'steps' : self.steps,
            'seed' : self.seed,
            'activation_dim' : self.ae.activation_dim,
            'dict_size' : self.ae.dict_size,
            'k': self.ae.k,
            'device' : self.device,
            "layer" : self.layer,
            'lm_name' : self.lm_name,
            'wandb_name' : self.wandb_name,
            'submodule_name' : self.submodule_name,
        }
298 |
--------------------------------------------------------------------------------
/dictionary_learning/trainers/trainer.py:
--------------------------------------------------------------------------------
class SAETrainer:
    """
    Base class for SAE training algorithms.

    Subclasses implement `update` (one training step on a batch) and override
    `config`. Any attribute names listed in `logging_parameters` are reported
    by `get_logging_parameters`.
    """
    def __init__(self, seed=None):
        self.logging_parameters = []
        self.seed = seed

    def update(self,
               step, # index of step in training
               activations, # of shape [batch_size, d_submodule]
               ):
        """Perform one training step. The base implementation does nothing."""
        return None

    def get_logging_parameters(self):
        """
        Return {name: value} for every name in `logging_parameters` set on this trainer.

        Names that are not attributes are skipped with a printed warning.
        """
        stats = {}
        for param in self.logging_parameters:
            if not hasattr(self, param):
                print(f"Warning: {param} not found in {self}")
                continue
            stats[param] = getattr(self, param)
        return stats

    @property
    def config(self):
        """Serializable description of the trainer; subclasses override this."""
        return dict(wandb_name='trainer')
29 |
--------------------------------------------------------------------------------
/dictionary_learning/training.py:
--------------------------------------------------------------------------------
1 | """
2 | Training dictionaries
3 | """
4 |
5 | import torch as t
6 | from .dictionary import AutoEncoder
7 | import os
8 | from tqdm import tqdm
9 | from .trainers.standard import StandardTrainer
10 | import wandb
11 | import json
12 | from .utils import cfg_filename
13 | # from .evaluation import evaluate
14 |
def _collect_logs(trainers, act, step, activations_split_by_head, transcoder):
    """
    Build the wandb log dict for one logging step.

    Each trainer is evaluated under no_grad with `logging=True` on its own clone
    of `act`, so `act` is left untouched for the training step. Keys are
    prefixed with "<wandb_name>-<trainer index>/".
    """
    log = {}
    with t.no_grad():
        for i, trainer in enumerate(trainers):
            x = act.clone()
            if activations_split_by_head: # act.shape: [batch, pos, n_heads, d_head]
                x = x[..., i, :]
            trainer_name = f'{trainer.config["wandb_name"]}-{i}'
            x, x_hat, f, losslog = trainer.loss(x, step=step, logging=True)

            # L0: mean count of non-zero features along the last dim
            l0 = (f != 0).float().sum(dim=-1).mean().item()

            if not transcoder:
                # fraction of variance explained
                total_variance = t.var(x, dim=0).sum()
                residual_variance = t.var(x - x_hat, dim=0).sum()
                frac_variance_explained = (1 - residual_variance / total_variance)
                log[f'{trainer_name}/frac_variance_explained'] = frac_variance_explained.item()
            # TODO: adapt frac_variance_explained for transcoders

            # log parameters from training
            log.update({f'{trainer_name}/{k}' : v for k, v in losslog.items()})
            log[f'{trainer_name}/l0'] = l0
            for name, value in trainer.get_logging_parameters().items():
                log[f'{trainer_name}/{name}'] = value
    return log


def trainSAE(
        data,
        trainer_configs=None,
        steps=None,
        save_steps=None,
        save_dir=None, # use {run} to refer to wandb run
        log_steps=None,
        activations_split_by_head=False, # set to true if data is shape [batch, pos, num_head, head_dim/resid_dim]
        transcoder=False,
):
    """
    Train SAEs using the given trainers.

    Args:
        data: iterable of activation batches. If it has a `config` attribute, that
            is saved as 'buffer' next to each trainer's config.
        trainer_configs: list of dicts. 'trainer' is the trainer class; the other
            entries are passed to it as keyword arguments. The dicts are NOT
            modified (they previously had 'trainer' deleted in place). Defaults to
            a single StandardTrainer config. Note that StandardTrainer also
            asserts that `layer` and `lm_name` are given.
        steps: stop after this many batches (None = until `data` is exhausted).
        save_steps: save a checkpoint every `save_steps` steps (None = never).
        save_dir: output root. When logging, "{run}" is replaced by the wandb run
            name. One subdirectory per trainer is created.
        log_steps: log metrics to wandb every `log_steps` steps (None = never).
            wandb.init is not called here, so presumably the caller must start
            a run first.
        activations_split_by_head: when logging, trainer i is evaluated on act[..., i, :].
        transcoder: skip the fraction-of-variance-explained metric.
    """
    if trainer_configs is None:
        trainer_configs = [
            {
                'trainer' : StandardTrainer,
                'dict_class' : AutoEncoder,
                'activation_dim' : 512,
                'dict_size' : 64*512, # StandardTrainer's parameter name
                'lr' : 1e-3,
                'l1_penalty' : 1e-1,
                'warmup_steps' : 1000,
                'resample_steps' : None,
                'seed' : None,
                'wandb_name' : 'StandardTrainer',
            }
        ]

    # Copy each config without the 'trainer' key rather than deleting it in
    # place. The copies are also what cfg_filename sees, so directory names
    # are unchanged.
    trainer_kwargs = []
    trainers = []
    for config in trainer_configs:
        kwargs = {k: v for k, v in config.items() if k != 'trainer'}
        trainers.append(config['trainer'](**kwargs))
        trainer_kwargs.append(kwargs)

    if log_steps is not None:
        # wandb.init(entity="sae-training", project="sae-training", config=...) is
        # intentionally disabled here; the caller is expected to start the run.
        # process save_dir in light of run name
        if save_dir is not None:
            save_dir = save_dir.format(run=wandb.run.name)

    # make save dirs, export config
    if save_dir is not None:
        save_dirs = [os.path.join(save_dir, f"{cfg_filename(kwargs)}") for kwargs in trainer_kwargs]
        for trainer, trainer_dir in zip(trainers, save_dirs):
            os.makedirs(trainer_dir, exist_ok=True)
            # save config
            config = {'trainer' : trainer.config}
            try:
                config['buffer'] = data.config
            except AttributeError:
                pass # data without a config (e.g. a plain iterable)
            with open(os.path.join(trainer_dir, "config.json"), 'w') as f:
                json.dump(config, f, indent=4)
    else:
        save_dirs = [None for _ in trainer_configs]

    for step, act in enumerate(tqdm(data, total=steps)):
        if steps is not None and step >= steps:
            break

        # logging (works on clones, so `act` is unchanged for training below;
        # previously `act` was rebound to the last trainer's sliced input)
        if log_steps is not None and step % log_steps == 0:
            log = _collect_logs(trainers, act, step, activations_split_by_head, transcoder)
            # TODO: also log evaluate(trainer.ae, data, device=trainer.device) metrics
            wandb.log(log, step=step)

        # saving
        if save_steps is not None and step % save_steps == 0:
            for trainer_dir, trainer in zip(save_dirs, trainers):
                if trainer_dir is None:
                    continue
                checkpoint_dir = os.path.join(trainer_dir, "checkpoints")
                os.makedirs(checkpoint_dir, exist_ok=True)
                t.save(
                    trainer.ae.state_dict(),
                    os.path.join(checkpoint_dir, f"ae_{step}.pt")
                )

        # training
        # NOTE(review): update() receives the full batch even when
        # activations_split_by_head is set, whereas logging slices per trainer;
        # confirm this is intended.
        for trainer in trainers:
            trainer.update(step, act)

    # save final SAEs
    for trainer_dir, trainer in zip(save_dirs, trainers):
        if trainer_dir is not None:
            t.save(trainer.ae.state_dict(), os.path.join(trainer_dir, "ae.pt"))

    # The wandb run is intentionally not finished here (wandb.finish() disabled);
    # the caller owns the run.
--------------------------------------------------------------------------------
/dictionary_learning/utils.py:
--------------------------------------------------------------------------------
1 | from datasets import load_dataset
2 | import zstandard as zstd
3 | import io
4 | import json
5 | import argparse
6 |
def hf_dataset_to_generator(dataset_name, split='train', streaming=True):
    """
    Return a generator over the 'text' field of a Hugging Face dataset split.

    The dataset object is created when this function is called. Rows are only
    iterated as the returned generator is consumed.
    """
    dataset = load_dataset(dataset_name, split=split, streaming=streaming)

    def _texts():
        for row in dataset:
            yield row['text']

    return _texts()
15 |
def zst_to_generator(data_path):
    """
    Load a dataset from a .jsonl.zst file.

    Each JSON line is assumed to have a 'text' field. Returns a generator that
    yields those texts. The file is opened immediately, so a bad path still fails
    at call time. It is closed when the generator is exhausted, raises, or is
    closed/garbage-collected after being started (previously it was never
    closed).
    """
    compressed_file = open(data_path, 'rb')
    try:
        dctx = zstd.ZstdDecompressor()
        reader = dctx.stream_reader(compressed_file)
        text_stream = io.TextIOWrapper(reader, encoding='utf-8')
    except Exception:
        compressed_file.close()
        raise

    def generator():
        try:
            for line in text_stream:
                yield json.loads(line)['text']
        finally:
            text_stream.close()
            compressed_file.close() # no-op if already closed via the stream
    return generator()
29 |
def cfg_filename(cfg):
    """
    Build a filesystem-friendly name from a config mapping.

    Each entry becomes "key:value", where value is str(value) with every "/"
    removed, truncated to its last 20 characters. Entries are joined by "_" in
    the mapping's iteration order.
    """
    def _entry(key):
        value = str(cfg[key]).replace("/", "")
        return f"{key}:{value[-20:]}"

    return '_'.join(_entry(key) for key in cfg)
37 |
def str2bool(v):
    """
    Parse a command-line boolean flag (for use as an argparse `type=`).

    bool inputs are returned unchanged. Strings are matched case-insensitively
    against yes/true/t/y/1 and no/false/f/n/0. Anything else raises
    argparse.ArgumentTypeError.
    """
    if isinstance(v, bool):
        return v
    lowered = v.lower()
    truthy = ('yes', 'true', 't', 'y', '1')
    falsy = ('no', 'false', 'f', 'n', '0')
    if lowered in truthy:
        return True
    if lowered in falsy:
        return False
    raise argparse.ArgumentTypeError('Boolean value expected.')
47 |
--------------------------------------------------------------------------------
/results/1on/l0_deltace.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amudide/switch_sae/1b484570c74063b8331cb658a6bbf1964048a7de/results/1on/l0_deltace.png
--------------------------------------------------------------------------------
/results/1on/l0_mse.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amudide/switch_sae/1b484570c74063b8331cb658a6bbf1964048a7de/results/1on/l0_mse.png
--------------------------------------------------------------------------------
/results/1on/l0_recovered.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amudide/switch_sae/1b484570c74063b8331cb658a6bbf1964048a7de/results/1on/l0_recovered.png
--------------------------------------------------------------------------------
/results/1on_lb/1on_lb.csv:
--------------------------------------------------------------------------------
1 | "Name","State","Notes","User","Tags","Created","Runtime","Sweep","SwitchAutoEncoder-0.activation_dim","SwitchAutoEncoder-0.auxk_alpha","SwitchAutoEncoder-0.decay_start","SwitchAutoEncoder-0.device","SwitchAutoEncoder-0.dict_class","SwitchAutoEncoder-0.dict_size","SwitchAutoEncoder-0.experts","SwitchAutoEncoder-0.heaviside","SwitchAutoEncoder-0.k","SwitchAutoEncoder-0.layer","SwitchAutoEncoder-0.lb_alpha","SwitchAutoEncoder-0.lm_name","SwitchAutoEncoder-0.seed","SwitchAutoEncoder-0.steps","SwitchAutoEncoder-0.trainer","SwitchAutoEncoder-0.wandb_name","SwitchAutoEncoder-1.activation_dim","SwitchAutoEncoder-1.auxk_alpha","SwitchAutoEncoder-1.decay_start","SwitchAutoEncoder-1.device","SwitchAutoEncoder-1.dict_class","SwitchAutoEncoder-1.dict_size","SwitchAutoEncoder-1.experts","SwitchAutoEncoder-1.heaviside","SwitchAutoEncoder-1.k","SwitchAutoEncoder-1.layer","SwitchAutoEncoder-1.lb_alpha","SwitchAutoEncoder-1.lm_name","SwitchAutoEncoder-1.seed","SwitchAutoEncoder-1.steps","SwitchAutoEncoder-1.trainer","SwitchAutoEncoder-1.wandb_name","SwitchAutoEncoder-0/auxk_loss","SwitchAutoEncoder-0/cossim","SwitchAutoEncoder-0/dead_features","SwitchAutoEncoder-0/effective_l0","SwitchAutoEncoder-0/frac_alive","SwitchAutoEncoder-0/frac_recovered","SwitchAutoEncoder-0/frac_variance_explained","SwitchAutoEncoder-0/l0","SwitchAutoEncoder-0/l1_loss","SwitchAutoEncoder-0/l2_loss","SwitchAutoEncoder-0/l2_ratio","SwitchAutoEncoder-0/lb_loss","SwitchAutoEncoder-0/loss","SwitchAutoEncoder-0/loss_original","SwitchAutoEncoder-0/loss_reconstructed","SwitchAutoEncoder-0/loss_zero","SwitchAutoEncoder-0/mse_loss","SwitchAutoEncoder-1/auxk_loss","SwitchAutoEncoder-1/cossim","SwitchAutoEncoder-1/dead_features","SwitchAutoEncoder-1/effective_l0","SwitchAutoEncoder-1/frac_alive","SwitchAutoEncoder-1/frac_recovered","SwitchAutoEncoder-1/frac_variance_explained","SwitchAutoEncoder-1/l0","SwitchAutoEncoder-1/l1_loss","SwitchAutoEncoder-1/l2_loss","SwitchAutoEncoder-1/l2_ratio","SwitchAutoEncoder-1/lb_
loss","SwitchAutoEncoder-1/loss","SwitchAutoEncoder-1/loss_original","SwitchAutoEncoder-1/loss_reconstructed","SwitchAutoEncoder-1/loss_zero","SwitchAutoEncoder-1/mse_loss"
2 | "clean-frog-6","finished","-","","","2024-07-30T01:01:23.000Z","47755","","768","0.03125","80000","cuda:7","dictionary_learning.trainers.switch1on.SwitchAutoEncoder","24576","32","false","64","8","100","openai-community/gpt2","0","100000","dictionary_learning.trainers.switch1on.SwitchTrainer","SwitchAutoEncoder","768","0.03125","80000","cuda:7","dictionary_learning.trainers.switch1on.SwitchAutoEncoder","24576","32","false","64","8","300","openai-community/gpt2","0","100000","dictionary_learning.trainers.switch1on.SwitchTrainer","SwitchAutoEncoder","0.000044464417442213744","0.9567948579788208","3228","64","0.00004069010537932627","0.9899426102638244","0.9863746166229248","253.58837890625","654.6430053710938","34.097957611083984","0.9548745155334472","1.0313754081726074","80446.0546875","3.3303747177124023","3.471522569656372","17.364582061767578","1241.935791015625","0.00004988056025467813","0.9448071718215942","8491","64","0.00004069010537932627","0.9861976504325868","0.9729151725769044","408.0152587890625","714.1879272460938","38.533409118652344","0.942345142364502","1.0331422090530396","239588.09375","3.26992130279541","3.467707872390747","17.59981346130371","1566.963623046875"
3 | "fanciful-oath-4","finished","-","","","2024-07-30T01:01:23.000Z","46459","","768","0.03125","80000","cuda:6","dictionary_learning.trainers.switch1on.SwitchAutoEncoder","24576","32","false","64","8","10","openai-community/gpt2","0","100000","dictionary_learning.trainers.switch1on.SwitchTrainer","SwitchAutoEncoder","768","0.03125","80000","cuda:6","dictionary_learning.trainers.switch1on.SwitchAutoEncoder","24576","32","false","64","8","30","openai-community/gpt2","0","100000","dictionary_learning.trainers.switch1on.SwitchTrainer","SwitchAutoEncoder","0.000005703236183762783","0.9625418186187744","52","64","0.00004069010537932627","0.9920895099639891","0.9881578087806702","195.4332275390625","607.35693359375","31.696577072143555","0.9622729420661926","1.0371432304382324","9045.2392578125","3.3303747177124023","3.441391944885254","17.364582061767578","1079.396484375","0.000008642778993817046","0.9617507457733154","79","64","0.00004069010537932627","0.9917919039726256","0.9810457825660706","192.5859375","579.9859008789062","32.029884338378906","0.9615445733070374","1.0366272926330566","24975.3359375","3.26992130279541","3.387542963027954","17.59981346130371","1096.583740234375"
4 | "wild-valley-4","finished","-","","","2024-07-30T01:01:23.000Z","45651","","768","0.03125","80000","cuda:5","dictionary_learning.trainers.switch1on.SwitchAutoEncoder","24576","32","false","64","8","1","openai-community/gpt2","0","100000","dictionary_learning.trainers.switch1on.SwitchTrainer","SwitchAutoEncoder","768","0.03125","80000","cuda:5","dictionary_learning.trainers.switch1on.SwitchAutoEncoder","24576","32","false","64","8","3","openai-community/gpt2","0","100000","dictionary_learning.trainers.switch1on.SwitchTrainer","SwitchAutoEncoder","7.480709314222622e-7","0.9632782936096193","7","64","0.00004069010537932627","0.9921401739120485","0.988391637802124","197.5400390625","603.214599609375","31.353612899780273","0.96281236410141","1.0378597974777222","1851.8035888671875","3.3303747177124023","3.4406819343566895","17.364582061767578","1058.0706787109375","0.000002888723429350648","0.9625215530395508","27","64","0.00004069010537932627","0.9922477006912231","0.9814417958259584","198.2266845703125","574.055419921875","31.630542755126953","0.9623088836669922","1.0359835624694824","3454.05517578125","3.26992130279541","3.3810112476348877","17.59981346130371","1073.6748046875"
5 | "different-feather-3","finished","-","","","2024-07-30T00:57:52.000Z","46911","","768","0.03125","80000","cuda:7","dictionary_learning.trainers.switch1on.SwitchAutoEncoder","24576","32","false","64","8","0.1","openai-community/gpt2","0","100000","dictionary_learning.trainers.switch1on.SwitchTrainer","SwitchAutoEncoder","768","0.03125","80000","cuda:7","dictionary_learning.trainers.switch1on.SwitchAutoEncoder","24576","32","false","64","8","0.3","openai-community/gpt2","0","100000","dictionary_learning.trainers.switch1on.SwitchTrainer","SwitchAutoEncoder","0.00004044335219077766","0.963446080684662","364","64","0.00004069010537932627","0.9921776056289672","0.9884382486343384","199.870361328125","601.8258056640625","31.283111572265625","0.9632078409194946","1.1113539934158323","1141.703369140625","3.3303747177124023","3.4401562213897705","17.364582061767578","1053.83154296875","0.000005634989520331146","0.9629814624786376","52","64","0.00004069010537932627","0.992292046546936","0.9816661477088928","199.752685546875","573.129638671875","31.42491912841797","0.962762176990509","1.045872926712036","1295.9554443359375","3.26992130279541","3.380375623703003","17.59981346130371","1060.68994140625"
6 | "feasible-plasma-1","finished","-","","","2024-07-30T00:57:40.000Z","47790","","768","0.03125","80000","cuda:6","dictionary_learning.trainers.switch1on.SwitchAutoEncoder","24576","32","false","64","8","0.01","openai-community/gpt2","0","100000","dictionary_learning.trainers.switch1on.SwitchTrainer","SwitchAutoEncoder","768","0.03125","80000","cuda:6","dictionary_learning.trainers.switch1on.SwitchAutoEncoder","24576","32","false","64","8","0.03","openai-community/gpt2","0","100000","dictionary_learning.trainers.switch1on.SwitchTrainer","SwitchAutoEncoder","0.000043194224417675287","0.9620879292488098","1624","64","0.00004069010537932627","0.991884469985962","0.988021969795227","203.5709228515625","604.99609375","31.848587036132812","0.9618685245513916","1.6018588542938232","1101.450927734375","3.3303747177124023","3.444269895553589","17.364582061767578","1091.7822265625","0.000042685824155341834","0.9621524214744568","1215","64","0.00004069010537932627","0.9919402003288268","0.9812399744987488","202.982666015625","575.129638671875","31.77358627319336","0.962187886238098","1.3047212362289429","1103.21728515625","3.26992130279541","3.3854167461395264","17.59981346130371","1085.34228515625"
7 | "summer-surf-2","finished","-","","","2024-07-30T00:57:40.000Z","47371","","768","0.03125","80000","cuda:5","dictionary_learning.trainers.switch1on.SwitchAutoEncoder","24576","32","false","64","8","0.001","openai-community/gpt2","0","100000","dictionary_learning.trainers.switch1on.SwitchTrainer","SwitchAutoEncoder","768","0.03125","80000","cuda:5","dictionary_learning.trainers.switch1on.SwitchAutoEncoder","24576","32","false","64","8","0.003","openai-community/gpt2","0","100000","dictionary_learning.trainers.switch1on.SwitchTrainer","SwitchAutoEncoder","0.000043277785152895376","0.9617789387702942","2094","64","0.00004069010537932627","0.9917656779289246","0.9879220128059388","204.457763671875","608.1683349609375","31.98760986328125","0.9616813659667968","1.7181545495986938","1099.705322265625","3.3303747177124023","3.445937156677246","17.364582061767578","1100.883056640625","0.00004343598266132176","0.9613968133926392","1650","64","0.00004069010537932627","0.9916406869888306","0.980884611606598","205.2655029296875","578.4508056640625","32.08662796020508","0.961195707321167","1.6460076570510864","1097.1473388671875","3.26992130279541","3.38970947265625","17.59981346130371","1105.91845703125"
--------------------------------------------------------------------------------
/results/1on_lb/alpha_deltace.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amudide/switch_sae/1b484570c74063b8331cb658a6bbf1964048a7de/results/1on_lb/alpha_deltace.png
--------------------------------------------------------------------------------
/results/1on_lb/alpha_lossrec.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amudide/switch_sae/1b484570c74063b8331cb658a6bbf1964048a7de/results/1on_lb/alpha_lossrec.png
--------------------------------------------------------------------------------
/results/1on_lb/alpha_mse.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amudide/switch_sae/1b484570c74063b8331cb658a6bbf1964048a7de/results/1on_lb/alpha_mse.png
--------------------------------------------------------------------------------
/results/efficiency/efficiency.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amudide/switch_sae/1b484570c74063b8331cb658a6bbf1964048a7de/results/efficiency/efficiency.png
--------------------------------------------------------------------------------
/results/gated_lr/gated_l0_lossrec.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amudide/switch_sae/1b484570c74063b8331cb658a6bbf1964048a7de/results/gated_lr/gated_l0_lossrec.png
--------------------------------------------------------------------------------
/results/gated_lr/gated_l0_mse.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amudide/switch_sae/1b484570c74063b8331cb658a6bbf1964048a7de/results/gated_lr/gated_l0_mse.png
--------------------------------------------------------------------------------
/results/gated_lr/gated_mse_deltace.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amudide/switch_sae/1b484570c74063b8331cb658a6bbf1964048a7de/results/gated_lr/gated_mse_deltace.png
--------------------------------------------------------------------------------
/results/gated_lr/primary-gated-1e-3.csv:
--------------------------------------------------------------------------------
1 | "Name","State","Notes","User","Tags","Created","Runtime","Sweep","GatedSAETrainer-0.activation_dim","GatedSAETrainer-0.device","GatedSAETrainer-0.dict_class","GatedSAETrainer-0.dict_size","GatedSAETrainer-0.l1_penalty","GatedSAETrainer-0.layer","GatedSAETrainer-0.lm_name","GatedSAETrainer-0.lr","GatedSAETrainer-0.seed","GatedSAETrainer-0.trainer","GatedSAETrainer-0.wandb_name","GatedSAETrainer-0.warmup_steps","GatedSAETrainer-1.activation_dim","GatedSAETrainer-1.device","GatedSAETrainer-1.dict_class","GatedSAETrainer-1.dict_size","GatedSAETrainer-1.l1_penalty","GatedSAETrainer-1.layer","GatedSAETrainer-1.lm_name","GatedSAETrainer-1.lr","GatedSAETrainer-1.seed","GatedSAETrainer-1.trainer","GatedSAETrainer-1.wandb_name","GatedSAETrainer-1.warmup_steps","GatedSAETrainer-0/aux_loss","GatedSAETrainer-0/cossim","GatedSAETrainer-0/frac_alive","GatedSAETrainer-0/frac_recovered","GatedSAETrainer-0/frac_variance_explained","GatedSAETrainer-0/l0","GatedSAETrainer-0/l1_loss","GatedSAETrainer-0/l2_loss","GatedSAETrainer-0/l2_ratio","GatedSAETrainer-0/loss","GatedSAETrainer-0/loss_original","GatedSAETrainer-0/loss_reconstructed","GatedSAETrainer-0/loss_zero","GatedSAETrainer-0/mse_loss","GatedSAETrainer-0/sparsity_loss","GatedSAETrainer-1/aux_loss","GatedSAETrainer-1/cossim","GatedSAETrainer-1/frac_alive","GatedSAETrainer-1/frac_recovered","GatedSAETrainer-1/frac_variance_explained","GatedSAETrainer-1/l0","GatedSAETrainer-1/l1_loss","GatedSAETrainer-1/l2_loss","GatedSAETrainer-1/l2_ratio","GatedSAETrainer-1/loss","GatedSAETrainer-1/loss_original","GatedSAETrainer-1/loss_reconstructed","GatedSAETrainer-1/loss_zero","GatedSAETrainer-1/mse_loss","GatedSAETrainer-1/sparsity_loss"
2 | "sunny-planet-5","finished","-","","","2024-07-16T17:57:45.000Z","64564","","768","cuda:7","dictionary_learning.dictionary.GatedAutoEncoder","24576","20.56","8","openai-community/gpt2","0.001","0","dictionary_learning.trainers.gdm.GatedSAETrainer","GatedSAETrainer","1000","768","cuda:7","dictionary_learning.dictionary.GatedAutoEncoder","24576","30","8","openai-community/gpt2","0.001","0","dictionary_learning.trainers.gdm.GatedSAETrainer","GatedSAETrainer","1000","3218.92138671875","0.920775294303894","0.00004069010537932627","0.9675883054733276","0.9754146933555604","14.7557373046875","277.45941162109375","46.65650939941406","0.9305692911148072","9132.3642578125","3.3303747177124023","3.7852470874786377","17.364582061767578","2245.533203125","175.97865295410156","4237.42578125","0.8820405602455139","0.00004069010537932627","0.9322636723518372","0.943988561630249","7.1939697265625","159.96339416503906","55.96257400512695","0.882667064666748","10325.232421875","3.26992130279541","4.240576267242432","17.59981346130371","3240.766845703125","95.46398162841795"
3 | "clear-frog-4","finished","-","","","2024-07-16T17:57:39.000Z","65896","","768","cuda:6","dictionary_learning.dictionary.GatedAutoEncoder","24576","9.65","8","openai-community/gpt2","0.001","0","dictionary_learning.trainers.gdm.GatedSAETrainer","GatedSAETrainer","1000","768","cuda:6","dictionary_learning.dictionary.GatedAutoEncoder","24576","14.09","8","openai-community/gpt2","0.001","0","dictionary_learning.trainers.gdm.GatedSAETrainer","GatedSAETrainer","1000","1701.8465576171875","0.9534698724746704","0.00004069010537932627","0.9877625703811646","0.9854307770729064","38.939208984375","344.64532470703125","36.01933670043945","0.9426600933074952","5541.484375","3.3303747177124023","3.5021169185638428","17.364582061767578","1338.38232421875","262.4046325683594","2458.522705078125","0.94084894657135","0.00004069010537932627","0.981767475605011","0.9708531498908995","25.7823486328125","310.1355285644531","40.50748062133789","0.9534173607826232","7219.072265625","3.26992130279541","3.5311920642852783","17.59981346130371","1692.3714599609375","211.6930694580078"
4 | "drawn-shadow-2","finished","-","","","2024-07-16T17:56:25.000Z","66134","","768","cuda:6","dictionary_learning.dictionary.GatedAutoEncoder","24576","4.53","8","openai-community/gpt2","0.001","0","dictionary_learning.trainers.gdm.GatedSAETrainer","GatedSAETrainer","1000","768","cuda:6","dictionary_learning.dictionary.GatedAutoEncoder","24576","6.62","8","openai-community/gpt2","0.001","0","dictionary_learning.trainers.gdm.GatedSAETrainer","GatedSAETrainer","1000","1053.4622802734375","0.9675292372703552","0.00004069010537932627","0.9922187924385072","0.9881061911582948","92.7789306640625","432.4425048828125","31.512962341308594","0.937780499458313","3639.833984375","3.3303747177124023","3.4395782947540283","17.364582061767578","1114.6705322265625","375.8903198242187","1358.39306640625","0.9616646766662598","0.00004069010537932627","0.9908108115196228","0.9810226559638976","62.830810546875","380.2125244140625","32.97096633911133","0.9485290050506592","4437.47705078125","3.26992130279541","3.401602268218994","17.59981346130371","1121.4132080078125","303.4264221191406"
5 | "hearty-music-1","finished","-","","","2024-07-16T17:56:25.000Z","65688","","768","cuda:5","dictionary_learning.dictionary.GatedAutoEncoder","24576","2.13","8","openai-community/gpt2","0.001","0","dictionary_learning.trainers.gdm.GatedSAETrainer","GatedSAETrainer","1000","768","cuda:5","dictionary_learning.dictionary.GatedAutoEncoder","24576","3.11","8","openai-community/gpt2","0.001","0","dictionary_learning.trainers.gdm.GatedSAETrainer","GatedSAETrainer","1000","899.0625610351562","0.9798789620399476","0.00004069010537932627","0.994871973991394","0.9936583042144777","273.841796875","704.4373168945312","24.840396881103516","0.9749097228050232","2705.451416015625","3.3303747177124023","3.4023427963256836","17.364582061767578","684.5126953125","537.3072509765625","1063.873046875","0.9749206304550172","0.00004069010537932627","0.9942847490310668","0.9875259399414062","173.9610595703125","553.797119140625","27.179105758666992","0.981850266456604","3166.013916015625","3.26992130279541","3.3518197536468506","17.59981346130371","779.90478515625","412.6763000488281"
6 | "smooth-aardvark-2","finished","-","","","2024-07-16T17:56:25.000Z","65729","","768","cuda:4","dictionary_learning.dictionary.GatedAutoEncoder","24576","1","8","openai-community/gpt2","0.001","0","dictionary_learning.trainers.gdm.GatedSAETrainer","GatedSAETrainer","1000","768","cuda:4","dictionary_learning.dictionary.GatedAutoEncoder","24576","1.46","8","openai-community/gpt2","0.001","0","dictionary_learning.trainers.gdm.GatedSAETrainer","GatedSAETrainer","1000","725.191650390625","0.9857865571975708","0.00004069010537932627","0.9952702522277832","0.9964048862457277","481.456787109375","998.9766235351562","21.06471061706543","0.99408757686615","2107.7939453125","3.3303747177124023","3.3967533111572266","17.364582061767578","518.01513671875","808.7833251953125","741.2411499023438","0.9823729991912842","0.00004069010537932627","0.9945960640907288","0.9929369688034058","380.1922607421875","852.5096435546875","22.97112274169922","0.9866905212402344","2290.761474609375","3.26992130279541","3.347358465194702","17.59981346130371","580.014404296875","680.2507934570312"
--------------------------------------------------------------------------------
/results/gated_lr/primary-gated-3e-4.csv:
--------------------------------------------------------------------------------
1 | "Name","State","Notes","User","Tags","Created","Runtime","Sweep","GatedSAETrainer-0.activation_dim","GatedSAETrainer-0.device","GatedSAETrainer-0.dict_class","GatedSAETrainer-0.dict_size","GatedSAETrainer-0.l1_penalty","GatedSAETrainer-0.layer","GatedSAETrainer-0.lm_name","GatedSAETrainer-0.lr","GatedSAETrainer-0.seed","GatedSAETrainer-0.trainer","GatedSAETrainer-0.wandb_name","GatedSAETrainer-0.warmup_steps","GatedSAETrainer-1.activation_dim","GatedSAETrainer-1.device","GatedSAETrainer-1.dict_class","GatedSAETrainer-1.dict_size","GatedSAETrainer-1.l1_penalty","GatedSAETrainer-1.layer","GatedSAETrainer-1.lm_name","GatedSAETrainer-1.lr","GatedSAETrainer-1.seed","GatedSAETrainer-1.trainer","GatedSAETrainer-1.wandb_name","GatedSAETrainer-1.warmup_steps","GatedSAETrainer-0/aux_loss","GatedSAETrainer-0/cossim","GatedSAETrainer-0/frac_alive","GatedSAETrainer-0/frac_recovered","GatedSAETrainer-0/frac_variance_explained","GatedSAETrainer-0/l0","GatedSAETrainer-0/l1_loss","GatedSAETrainer-0/l2_loss","GatedSAETrainer-0/l2_ratio","GatedSAETrainer-0/loss","GatedSAETrainer-0/loss_original","GatedSAETrainer-0/loss_reconstructed","GatedSAETrainer-0/loss_zero","GatedSAETrainer-0/mse_loss","GatedSAETrainer-0/sparsity_loss","GatedSAETrainer-1/aux_loss","GatedSAETrainer-1/cossim","GatedSAETrainer-1/frac_alive","GatedSAETrainer-1/frac_recovered","GatedSAETrainer-1/frac_variance_explained","GatedSAETrainer-1/l0","GatedSAETrainer-1/l1_loss","GatedSAETrainer-1/l2_loss","GatedSAETrainer-1/l2_ratio","GatedSAETrainer-1/loss","GatedSAETrainer-1/loss_original","GatedSAETrainer-1/loss_reconstructed","GatedSAETrainer-1/loss_zero","GatedSAETrainer-1/mse_loss","GatedSAETrainer-1/sparsity_loss"
2 | "rare-firefly-6","finished","-","","","2024-07-11T18:51:52.000Z","51591","","768","cuda:4","dictionary_learning.dictionary.GatedAutoEncoder","24576","5","8","openai-community/gpt2","0.0003","0","dictionary_learning.trainers.gdm.GatedSAETrainer","GatedSAETrainer","1000","768","cuda:4","dictionary_learning.dictionary.GatedAutoEncoder","24576","6","8","openai-community/gpt2","0.0003","0","dictionary_learning.trainers.gdm.GatedSAETrainer","GatedSAETrainer","1000","1163.14306640625","0.973246932029724","0.00004069010537932627","0.9953534007072448","0.9913633465766908","127.900146484375","528.8916015625","27.75778579711914","0.9604530334472656","3927.76611328125","3.3303747177124023","3.3955864906311035","17.364582061767578","791.554931640625","394.5193481445313","1230.1353759765625","0.9700021743774414","0.00004069010537932627","0.9947134852409364","0.984986126422882","104.4046630859375","507.0380859375","29.08409881591797","0.9812511205673218","4367.89501953125","3.26992130279541","3.3456764221191406","17.59981346130371","870.0953979492188","376.71624755859375"
3 | "dashing-lion-5","finished","-","","","2024-07-11T05:25:59.000Z","50846","","768","cuda:7","dictionary_learning.dictionary.GatedAutoEncoder","24576","20.56","8","openai-community/gpt2","0.0003","0","dictionary_learning.trainers.gdm.GatedSAETrainer","GatedSAETrainer","1000","768","cuda:7","dictionary_learning.dictionary.GatedAutoEncoder","24576","30","8","openai-community/gpt2","0.0003","0","dictionary_learning.trainers.gdm.GatedSAETrainer","GatedSAETrainer","1000","3559.26806640625","0.9084912538528442","0.00004069010537932627","0.9579382538795472","0.9719541668891908","12.566650390625","378.3607177734375","49.76805114746094","0.895888090133667","11920.619140625","3.3303747177124023","3.92067813873291","17.364582061767578","2559.174560546875","280.8070068359375","4629.88525390625","0.8717569708824158","0.00004069010537932627","0.9255402088165284","0.9396619200706482","8.206787109375","312.36065673828125","58.12748718261719","0.8806691765785217","15425.427734375","3.26992130279541","4.336921691894531","17.59981346130371","3492.701171875","243.12893676757812"
4 | "graceful-totem-4","finished","-","","","2024-07-11T05:25:45.000Z","51864","","768","cuda:6","dictionary_learning.dictionary.GatedAutoEncoder","24576","9.65","8","openai-community/gpt2","0.0003","0","dictionary_learning.trainers.gdm.GatedSAETrainer","GatedSAETrainer","1000","768","cuda:6","dictionary_learning.dictionary.GatedAutoEncoder","24576","14.09","8","openai-community/gpt2","0.0003","0","dictionary_learning.trainers.gdm.GatedSAETrainer","GatedSAETrainer","1000","1701.7183837890625","0.9564532041549684","0.00004069010537932627","0.9901224970817566","0.9862537980079652","48.8880615234375","440.683349609375","34.85649871826172","0.9686530232429504","6070.19482421875","3.3303747177124023","3.4689981937408447","17.364582061767578","1253.7666015625","322.1306457519531","2614.995361328125","0.934950828552246","0.00004069010537932627","0.9785899519920348","0.968268096446991","23.200927734375","387.74371337890625","42.21244812011719","0.9180249571800232","8781.263671875","3.26992130279541","3.5767250061035156","17.59981346130371","1837.4671630859375","305.5757751464844"
5 | "silver-puddle-3","finished","-","","","2024-07-11T05:25:35.000Z","51579","","768","cuda:5","dictionary_learning.dictionary.GatedAutoEncoder","24576","4.53","8","openai-community/gpt2","0.0003","0","dictionary_learning.trainers.gdm.GatedSAETrainer","GatedSAETrainer","1000","768","cuda:5","dictionary_learning.dictionary.GatedAutoEncoder","24576","6.62","8","openai-community/gpt2","0.0003","0","dictionary_learning.trainers.gdm.GatedSAETrainer","GatedSAETrainer","1000","1059.000244140625","0.975863754749298","0.00004069010537932627","0.9960606098175048","0.9922483563423156","159.738525390625","578.526123046875","26.311534881591797","0.9801684617996216","3686.526611328125","3.3303747177124023","3.3856611251831055","17.364582061767578","710.6533203125","420.4597473144531","1368.194091796875","0.9670422077178956","0.00004069010537932627","0.99338698387146","0.9835944175720216","82.66162109375","464.7086181640625","30.388513565063477","0.9570202827453612","4679.21875","3.26992130279541","3.364684820175171","17.59981346130371","950.4469604492188","356.64434814453125"
6 | "generous-fog-2","finished","-","","","2024-07-11T05:25:20.000Z","51422","","768","cuda:4","dictionary_learning.dictionary.GatedAutoEncoder","24576","2.13","8","openai-community/gpt2","0.0003","0","dictionary_learning.trainers.gdm.GatedSAETrainer","GatedSAETrainer","1000","768","cuda:4","dictionary_learning.dictionary.GatedAutoEncoder","24576","3.11","8","openai-community/gpt2","0.0003","0","dictionary_learning.trainers.gdm.GatedSAETrainer","GatedSAETrainer","1000","603.53125","0.9911146759986876","0.00004069010537932627","0.998424232006073","0.9973034858703612","529.562744140625","1140.15185546875","16.387954711914062","0.993000864982605","2631.47314453125","3.3303747177124023","3.352489471435547","17.364582061767578","278.48040771484375","825.9996337890625","819.4901123046875","0.9827598333358764","0.00004069010537932627","0.9972550868988036","0.9914193153381348","289.6943359375","723.2354736328125","22.326417922973633","0.9679235219955444","3019.125","3.26992130279541","3.3092551231384277","17.59981346130371","508.3369140625","546.3135986328125"
7 | "pleasant-oath-1","finished","-","","","2024-07-11T05:21:07.000Z","52028","","768","cuda:6","dictionary_learning.dictionary.GatedAutoEncoder","24576","1","8","openai-community/gpt2","0.0003","0","dictionary_learning.trainers.gdm.GatedSAETrainer","GatedSAETrainer","1000","768","cuda:6","dictionary_learning.dictionary.GatedAutoEncoder","24576","1.46","8","openai-community/gpt2","0.0003","0","dictionary_learning.trainers.gdm.GatedSAETrainer","GatedSAETrainer","1000","245.25555419921875","0.9967520833015442","0.00004069010537932627","0.9992234110832214","0.9996508359909058","645.68798828125","1738.0128173828125","9.955805778503418","0.995638370513916","1791.7958984375","3.3303747177124023","3.341273546218872","17.364582061767578","102.31087493896484","1449.95556640625","388.83349609375","0.9960089921951294","0.00004069010537932627","0.9988710284233092","0.9988107085227966","565.810791015625","1553.6287841796875","10.871599197387695","0.9927268028259276","2247.65771484375","3.26992130279541","3.2860991954803467","17.59981346130371","119.66344451904295","1189.346923828125"
--------------------------------------------------------------------------------
/results/gated_lr/primary-gated-5e-5.csv:
--------------------------------------------------------------------------------
1 | "Name","State","Notes","User","Tags","Created","Runtime","Sweep","GatedSAETrainer-0.activation_dim","GatedSAETrainer-0.device","GatedSAETrainer-0.dict_class","GatedSAETrainer-0.dict_size","GatedSAETrainer-0.l1_penalty","GatedSAETrainer-0.layer","GatedSAETrainer-0.lm_name","GatedSAETrainer-0.lr","GatedSAETrainer-0.seed","GatedSAETrainer-0.trainer","GatedSAETrainer-0.wandb_name","GatedSAETrainer-0.warmup_steps","GatedSAETrainer-1.activation_dim","GatedSAETrainer-1.device","GatedSAETrainer-1.dict_class","GatedSAETrainer-1.dict_size","GatedSAETrainer-1.l1_penalty","GatedSAETrainer-1.layer","GatedSAETrainer-1.lm_name","GatedSAETrainer-1.lr","GatedSAETrainer-1.seed","GatedSAETrainer-1.trainer","GatedSAETrainer-1.wandb_name","GatedSAETrainer-1.warmup_steps","GatedSAETrainer-0/aux_loss","GatedSAETrainer-0/cossim","GatedSAETrainer-0/frac_alive","GatedSAETrainer-0/frac_recovered","GatedSAETrainer-0/frac_variance_explained","GatedSAETrainer-0/l0","GatedSAETrainer-0/l1_loss","GatedSAETrainer-0/l2_loss","GatedSAETrainer-0/l2_ratio","GatedSAETrainer-0/loss","GatedSAETrainer-0/loss_original","GatedSAETrainer-0/loss_reconstructed","GatedSAETrainer-0/loss_zero","GatedSAETrainer-0/mse_loss","GatedSAETrainer-0/sparsity_loss","GatedSAETrainer-1/aux_loss","GatedSAETrainer-1/cossim","GatedSAETrainer-1/frac_alive","GatedSAETrainer-1/frac_recovered","GatedSAETrainer-1/frac_variance_explained","GatedSAETrainer-1/l0","GatedSAETrainer-1/l1_loss","GatedSAETrainer-1/l2_loss","GatedSAETrainer-1/l2_ratio","GatedSAETrainer-1/loss","GatedSAETrainer-1/loss_original","GatedSAETrainer-1/loss_reconstructed","GatedSAETrainer-1/loss_zero","GatedSAETrainer-1/mse_loss","GatedSAETrainer-1/sparsity_loss"
2 | "faithful-yogurt-5","finished","-","","","2024-07-13T08:40:04.000Z","78097","","768","cuda:7","dictionary_learning.dictionary.GatedAutoEncoder","24576","20.56","8","openai-community/gpt2","0.00005","0","dictionary_learning.trainers.gdm.GatedSAETrainer","GatedSAETrainer","1000","768","cuda:7","dictionary_learning.dictionary.GatedAutoEncoder","24576","30","8","openai-community/gpt2","0.00005","0","dictionary_learning.trainers.gdm.GatedSAETrainer","GatedSAETrainer","1000","21934.6796875","0.9164952039718628","0.00004069010537932627","0.9665752053260804","0.974133312702179","24.208984375","766.954833984375","47.832576751708984","0.9202543497085572","32263.42578125","3.3303747177124023","3.799464702606201","17.364582061767578","2357.88525390625","385.6487121582031","17812.669921875","0.875127911567688","0.00004069010537932627","0.9306480288505554","0.9410187005996704","11.570068359375","473.6762390136719","57.49211883544922","0.8786675930023193","30521.861328125","3.26992130279541","4.263727188110352","17.59981346130371","3412.41162109375","309.806396484375"
3 | "absurd-armadillo-4","finished","-","","","2024-07-13T08:39:48.000Z","51872","","768","cuda:6","dictionary_learning.dictionary.GatedAutoEncoder","24576","9.65","8","openai-community/gpt2","0.00005","0","dictionary_learning.trainers.gdm.GatedSAETrainer","GatedSAETrainer","1000","768","cuda:6","dictionary_learning.dictionary.GatedAutoEncoder","24576","14.09","8","openai-community/gpt2","0.00005","0","dictionary_learning.trainers.gdm.GatedSAETrainer","GatedSAETrainer","1000","6471.3515625","0.9502999782562256","0.00004069010537932627","0.9877426028251648","0.9842393398284912","86.309326171875","1533.554443359375","37.37459945678711","0.955512762069702","18172.8984375","3.3303747177124023","3.5023977756500244","17.364582061767578","1438.06787109375","1062.67578125","8154.59375","0.939529836177826","0.00004069010537932627","0.982342541217804","0.9703220129013062","35.64306640625","704.9683227539062","40.845314025878906","0.9403757452964784","17816.7109375","3.26992130279541","3.522951364517212","17.59981346130371","1717.0882568359375","564.1171264648438"
4 | "solar-smoke-2","finished","-","","","2024-07-13T08:39:44.000Z","52969","","768","cuda:3","dictionary_learning.dictionary.GatedAutoEncoder","24576","1","8","openai-community/gpt2","0.00005","0","dictionary_learning.trainers.gdm.GatedSAETrainer","GatedSAETrainer","1000","768","cuda:3","dictionary_learning.dictionary.GatedAutoEncoder","24576","1.46","8","openai-community/gpt2","0.00005","0","dictionary_learning.trainers.gdm.GatedSAETrainer","GatedSAETrainer","1000","176.92796325683594","0.999501407146454","0.00004069010537932627","0.9996459484100342","0.9998784065246582","733.0106201171875","2288.524169921875","3.913240909576416","0.996691107749939","2142.02099609375","3.3303747177124023","3.335343599319458","17.364582061767578","15.964256286621094","1948.878173828125","357.5518798828125","0.9984543323516846","0.00004069010537932627","0.9995059967041016","0.9992587566375732","805.2034912109375","2221.821044921875","6.802419662475586","0.9969862699508668","3115.2333984375","3.26992130279541","3.2770004272460938","17.59981346130371","48.07682037353515","1850.6416015625"
5 | "super-pond-1","finished","-","","","2024-07-13T08:39:44.000Z","51669","","768","cuda:4","dictionary_learning.dictionary.GatedAutoEncoder","24576","2.13","8","openai-community/gpt2","0.00005","0","dictionary_learning.trainers.gdm.GatedSAETrainer","GatedSAETrainer","1000","768","cuda:4","dictionary_learning.dictionary.GatedAutoEncoder","24576","3.11","8","openai-community/gpt2","0.00005","0","dictionary_learning.trainers.gdm.GatedSAETrainer","GatedSAETrainer","1000","3244.373046875","0.9685524702072144","0.00004069010537932627","0.9927898645401","0.9899781346321106","523.408447265625","4199.6708984375","30.41826057434082","0.9841817617416382","10093.599609375","3.3303747177124023","3.431563138961792","17.364582061767578","969.7067260742188","2751.1318359375","1491.67626953125","0.9849917888641356","0.00004069010537932627","0.9976735711097716","0.9923139214515686","530.9619140625","1741.81884765625","20.90458297729492","0.987014651298523","6693.18212890625","3.26992130279541","3.3032593727111816","17.59981346130371","447.7106323242187","1513.2750244140625"
6 | "serene-plant-2","finished","-","","","2024-07-13T08:39:44.000Z","51594","","768","cuda:5","dictionary_learning.dictionary.GatedAutoEncoder","24576","4.53","8","openai-community/gpt2","0.00005","0","dictionary_learning.trainers.gdm.GatedSAETrainer","GatedSAETrainer","1000","768","cuda:5","dictionary_learning.dictionary.GatedAutoEncoder","24576","6.62","8","openai-community/gpt2","0.00005","0","dictionary_learning.trainers.gdm.GatedSAETrainer","GatedSAETrainer","1000","3213.54931640625","0.964131236076355","0.00004069010537932627","0.9926445484161376","0.9886342883110046","308.4468994140625","2597.927734375","31.888023376464844","0.9749326705932616","12468.984375","3.3303747177124023","3.4336023330688477","17.364582061767578","1053.774658203125","1792.6346435546875","8968.025390625","0.9365889430046082","0.00004069010537932627","0.9807283878326416","0.96932053565979","95.2742919921875","1390.1123046875","41.98253631591797","0.941771388053894","17921.384765625","3.26992130279541","3.546081781387329","17.59981346130371","1815.49267578125","1141.9906005859375"
--------------------------------------------------------------------------------
/results/heaviside_softmax/experts16h.csv:
--------------------------------------------------------------------------------
1 | ,experts,heaviside,l0,mse_loss,frac_recovered,loss_original,loss_reconstructed,delta_ce
2 | 7,16.0,1.0,8.0,2643.10888671875,0.9520765542984008,3.26992130279541,3.9566588401794434,0.6867375373840332
3 | 5,16.0,1.0,32.0,1720.351806640625,0.9821312427520752,3.26992130279541,3.525978803634644,0.25605750083923384
4 | 3,16.0,1.0,64.0,1384.42578125,0.9890355467796326,3.26992130279541,3.4270408153533936,0.1571195125579834
5 | 1,16.0,1.0,128.0,1037.333251953125,0.9934381246566772,3.26992130279541,3.363952159881592,0.09403085708618164
6 |
--------------------------------------------------------------------------------
/results/heaviside_softmax/experts16s.csv:
--------------------------------------------------------------------------------
1 | ,experts,heaviside,l0,mse_loss,frac_recovered,loss_original,loss_reconstructed,delta_ce
2 | 6,16.0,0.0,8.0,2345.244384765625,0.9605113863945008,3.3303747177124023,3.884566307067871,0.5541915893554688
3 | 4,16.0,0.0,32.0,1412.8662109375,0.9869770407676696,3.3303747177124023,3.513141870498657,0.18276715278625444
4 | 2,16.0,0.0,64.0,1118.0126953125,0.9918711185455322,3.3303747177124023,3.444457054138184,0.1140823364257817
5 | 0,16.0,0.0,128.0,856.46142578125,0.9954178333282472,3.3303747177124023,3.394681453704834,0.06430673599243164
6 |
--------------------------------------------------------------------------------
/results/heaviside_softmax/experts_l0_deltace.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amudide/switch_sae/1b484570c74063b8331cb658a6bbf1964048a7de/results/heaviside_softmax/experts_l0_deltace.png
--------------------------------------------------------------------------------
/results/heaviside_softmax/experts_l0_lossrec.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amudide/switch_sae/1b484570c74063b8331cb658a6bbf1964048a7de/results/heaviside_softmax/experts_l0_lossrec.png
--------------------------------------------------------------------------------
/results/heaviside_softmax/experts_l0_mse.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amudide/switch_sae/1b484570c74063b8331cb658a6bbf1964048a7de/results/heaviside_softmax/experts_l0_mse.png
--------------------------------------------------------------------------------
/results/heaviside_softmax/primary-experts.csv:
--------------------------------------------------------------------------------
1 | "Name","State","Notes","User","Tags","Created","Runtime","Sweep","SwitchAutoEncoder-0.activation_dim","SwitchAutoEncoder-0.auxk_alpha","SwitchAutoEncoder-0.decay_start","SwitchAutoEncoder-0.device","SwitchAutoEncoder-0.dict_class","SwitchAutoEncoder-0.dict_size","SwitchAutoEncoder-0.experts","SwitchAutoEncoder-0.heaviside","SwitchAutoEncoder-0.k","SwitchAutoEncoder-0.layer","SwitchAutoEncoder-0.lm_name","SwitchAutoEncoder-0.seed","SwitchAutoEncoder-0.steps","SwitchAutoEncoder-0.trainer","SwitchAutoEncoder-0.wandb_name","SwitchAutoEncoder-1.activation_dim","SwitchAutoEncoder-1.auxk_alpha","SwitchAutoEncoder-1.decay_start","SwitchAutoEncoder-1.device","SwitchAutoEncoder-1.dict_class","SwitchAutoEncoder-1.dict_size","SwitchAutoEncoder-1.experts","SwitchAutoEncoder-1.heaviside","SwitchAutoEncoder-1.k","SwitchAutoEncoder-1.layer","SwitchAutoEncoder-1.lm_name","SwitchAutoEncoder-1.seed","SwitchAutoEncoder-1.steps","SwitchAutoEncoder-1.trainer","SwitchAutoEncoder-1.wandb_name","SwitchAutoEncoder-0/auxk_loss","SwitchAutoEncoder-0/cossim","SwitchAutoEncoder-0/dead_features","SwitchAutoEncoder-0/effective_l0","SwitchAutoEncoder-0/frac_alive","SwitchAutoEncoder-0/frac_recovered","SwitchAutoEncoder-0/frac_variance_explained","SwitchAutoEncoder-0/l0","SwitchAutoEncoder-0/l1_loss","SwitchAutoEncoder-0/l2_loss","SwitchAutoEncoder-0/l2_ratio","SwitchAutoEncoder-0/loss","SwitchAutoEncoder-0/loss_original","SwitchAutoEncoder-0/loss_reconstructed","SwitchAutoEncoder-0/loss_zero","SwitchAutoEncoder-0/mse_loss","SwitchAutoEncoder-1/auxk_loss","SwitchAutoEncoder-1/cossim","SwitchAutoEncoder-1/dead_features","SwitchAutoEncoder-1/effective_l0","SwitchAutoEncoder-1/frac_alive","SwitchAutoEncoder-1/frac_recovered","SwitchAutoEncoder-1/frac_variance_explained","SwitchAutoEncoder-1/l0","SwitchAutoEncoder-1/l1_loss","SwitchAutoEncoder-1/l2_loss","SwitchAutoEncoder-1/l2_ratio","SwitchAutoEncoder-1/loss","SwitchAutoEncoder-1/loss_original","SwitchAutoEncoder-1/loss_reconstructed","SwitchAutoE
ncoder-1/loss_zero","SwitchAutoEncoder-1/mse_loss"
2 | "leafy-frost-5","finished","-","","","2024-07-13T06:09:56.000Z","44661","","768","0.03125","80000","cuda:4","dictionary_learning.trainers.switch.SwitchAutoEncoder","24576","16","false","128","8","openai-community/gpt2","0","100000","dictionary_learning.trainers.switch.SwitchTrainer","SwitchAutoEncoder","768","0.03125","80000","cuda:4","dictionary_learning.trainers.switch.SwitchAutoEncoder","24576","16","true","128","8","openai-community/gpt2","0","100000","dictionary_learning.trainers.switch.SwitchTrainer","SwitchAutoEncoder","0.00003293702320661396","0.970380425453186","2266","128","0.00004069010537932627","0.9954178333282472","0.9893406629562378","346.756103515625","929.089599609375","27.995044708251953","0.970288336277008","846.9383544921875","3.3303747177124023","3.394681453704834","17.364582061767578","856.46142578125","0","0.9639195203781128","0","128","0.00004069010537932627","0.9934381246566772","0.987186849117279","344.84912109375","932.871826171875","31.138887405395508","0.9640032052993774","1024.9150390625","3.26992130279541","3.363952159881592","17.59981346130371","1037.333251953125"
3 | "winter-moon-4","finished","-","","","2024-07-13T06:09:42.000Z","44679","","768","0.03125","80000","cuda:7","dictionary_learning.trainers.switch.SwitchAutoEncoder","24576","16","false","64","8","openai-community/gpt2","0","100000","dictionary_learning.trainers.switch.SwitchTrainer","SwitchAutoEncoder","768","0.03125","80000","cuda:7","dictionary_learning.trainers.switch.SwitchAutoEncoder","24576","16","true","64","8","openai-community/gpt2","0","100000","dictionary_learning.trainers.switch.SwitchTrainer","SwitchAutoEncoder","0.00003745846333913505","0.9611231684684752","2292","64","0.00004069010537932627","0.9918711185455322","0.9860854744911194","221.5013427734375","616.9510498046875","32.06147384643555","0.9609488248825072","1105.567626953125","3.3303747177124023","3.4444570541381836","17.364582061767578","1118.0126953125","0","0.9514711499214172","0","64","0.00004069010537932627","0.9890355467796326","0.9828994274139404","234.9410400390625","667.9144287109375","35.96377182006836","0.9516751170158386","1366.359130859375","3.26992130279541","3.4270408153533936","17.59981346130371","1384.42578125"
4 | "likely-aardvark-3","finished","-","","","2024-07-13T06:09:22.000Z","43728","","768","0.03125","80000","cuda:6","dictionary_learning.trainers.switch.SwitchAutoEncoder","24576","16","false","32","8","openai-community/gpt2","0","100000","dictionary_learning.trainers.switch.SwitchTrainer","SwitchAutoEncoder","768","0.03125","80000","cuda:6","dictionary_learning.trainers.switch.SwitchAutoEncoder","24576","16","true","32","8","openai-community/gpt2","0","100000","dictionary_learning.trainers.switch.SwitchTrainer","SwitchAutoEncoder","0.00004270674253348261","0.9505878686904908","3495","32","0.00004069010537932627","0.9869770407676696","0.9824159145355223","147.52978515625","473.6759033203125","36.09650421142578","0.9501758217811584","1390.8076171875","3.3303747177124023","3.5131418704986572","17.364582061767578","1412.8662109375","0","0.9392642974853516","0","32","0.00004069010537932627","0.9821312427520752","0.9787499904632568","158.9560546875","509.13079833984375","40.07820129394531","0.9391745328903198","1695.3155517578125","3.26992130279541","3.5259788036346436","17.59981346130371","1720.351806640625"
5 | "genial-dawn-2","finished","-","","","2024-07-13T06:09:08.000Z","46116","","768","0.03125","80000","cuda:5","dictionary_learning.trainers.switch.SwitchAutoEncoder","24576","16","false","8","8","openai-community/gpt2","0","100000","dictionary_learning.trainers.switch.SwitchTrainer","SwitchAutoEncoder","768","0.03125","80000","cuda:5","dictionary_learning.trainers.switch.SwitchAutoEncoder","24576","16","true","8","8","openai-community/gpt2","0","100000","dictionary_learning.trainers.switch.SwitchTrainer","SwitchAutoEncoder","0.00006340954860206693","0.91660076379776","14566","8","0.00004069010537932627","0.9605113863945008","0.970811665058136","115.51416015625","384.5042419433594","46.83196258544922","0.9166349172592164","2320.48779296875","3.3303747177124023","3.884566307067871","17.364582061767578","2345.244384765625","0.00006867974298074841","0.905162513256073","6606","8","0.00004069010537932627","0.9520765542984008","0.967352032661438","137.64208984375","417.4818420410156","49.77260971069336","0.905208945274353","2604.93310546875","3.26992130279541","3.9566588401794434","17.59981346130371","2643.10888671875"
--------------------------------------------------------------------------------
/results/load_balance/alpha_deltace.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amudide/switch_sae/1b484570c74063b8331cb658a6bbf1964048a7de/results/load_balance/alpha_deltace.png
--------------------------------------------------------------------------------
/results/load_balance/alpha_lossrec.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amudide/switch_sae/1b484570c74063b8331cb658a6bbf1964048a7de/results/load_balance/alpha_lossrec.png
--------------------------------------------------------------------------------
/results/load_balance/alpha_mse.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amudide/switch_sae/1b484570c74063b8331cb658a6bbf1964048a7de/results/load_balance/alpha_mse.pdf
--------------------------------------------------------------------------------
/results/load_balance/alpha_mse.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amudide/switch_sae/1b484570c74063b8331cb658a6bbf1964048a7de/results/load_balance/alpha_mse.png
--------------------------------------------------------------------------------
/results/primary/flopmatch_l0_deltace.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amudide/switch_sae/1b484570c74063b8331cb658a6bbf1964048a7de/results/primary/flopmatch_l0_deltace.png
--------------------------------------------------------------------------------
/results/primary/flopmatch_l0_lossrec.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amudide/switch_sae/1b484570c74063b8331cb658a6bbf1964048a7de/results/primary/flopmatch_l0_lossrec.png
--------------------------------------------------------------------------------
/results/primary/flopmatch_l0_mse.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amudide/switch_sae/1b484570c74063b8331cb658a6bbf1964048a7de/results/primary/flopmatch_l0_mse.png
--------------------------------------------------------------------------------
/results/primary/flopmatch_mse_deltace.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amudide/switch_sae/1b484570c74063b8331cb658a6bbf1964048a7de/results/primary/flopmatch_mse_deltace.png
--------------------------------------------------------------------------------
/results/primary/l0_deltace.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amudide/switch_sae/1b484570c74063b8331cb658a6bbf1964048a7de/results/primary/l0_deltace.png
--------------------------------------------------------------------------------
/results/primary/l0_lossrec.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amudide/switch_sae/1b484570c74063b8331cb658a6bbf1964048a7de/results/primary/l0_lossrec.png
--------------------------------------------------------------------------------
/results/primary/l0_mse.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amudide/switch_sae/1b484570c74063b8331cb658a6bbf1964048a7de/results/primary/l0_mse.png
--------------------------------------------------------------------------------
/results/primary/mse_deltace.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amudide/switch_sae/1b484570c74063b8331cb658a6bbf1964048a7de/results/primary/mse_deltace.png
--------------------------------------------------------------------------------
/results/primary/primary-gated-clean.csv:
--------------------------------------------------------------------------------
1 | ,l0,mse_loss,frac_recovered,delta_ce
2 | 3,8.206787109375,3492.701171875,0.9255402088165284,1.067000389099121
3 | 2,12.566650390625,2559.174560546875,0.9579382538795472,0.5903034210205078
4 | 5,23.200927734375,1837.4671630859373,0.9785899519920348,0.3068037033081059
5 | 4,48.8880615234375,1253.7666015625,0.9901224970817566,0.13862347602844238
6 | 7,82.66162109375,950.4469604492188,0.99338698387146,0.09476351737976074
7 | 1,104.4046630859375,870.0953979492188,0.9947134852409364,0.07575511932373047
8 | 0,127.900146484375,791.554931640625,0.9953534007072448,0.06521177291870117
9 | 6,159.738525390625,710.6533203125,0.9960606098175048,0.05528640747070268
10 | 9,289.6943359375,508.3369140625,0.9972550868988036,0.03933382034301802
--------------------------------------------------------------------------------
/results/primary/primary-gated.csv:
--------------------------------------------------------------------------------
1 | "Name","State","Notes","User","Tags","Created","Runtime","Sweep","GatedSAETrainer-0.activation_dim","GatedSAETrainer-0.device","GatedSAETrainer-0.dict_class","GatedSAETrainer-0.dict_size","GatedSAETrainer-0.l1_penalty","GatedSAETrainer-0.layer","GatedSAETrainer-0.lm_name","GatedSAETrainer-0.lr","GatedSAETrainer-0.seed","GatedSAETrainer-0.trainer","GatedSAETrainer-0.wandb_name","GatedSAETrainer-0.warmup_steps","GatedSAETrainer-1.activation_dim","GatedSAETrainer-1.device","GatedSAETrainer-1.dict_class","GatedSAETrainer-1.dict_size","GatedSAETrainer-1.l1_penalty","GatedSAETrainer-1.layer","GatedSAETrainer-1.lm_name","GatedSAETrainer-1.lr","GatedSAETrainer-1.seed","GatedSAETrainer-1.trainer","GatedSAETrainer-1.wandb_name","GatedSAETrainer-1.warmup_steps","GatedSAETrainer-0/aux_loss","GatedSAETrainer-0/cossim","GatedSAETrainer-0/frac_alive","GatedSAETrainer-0/frac_recovered","GatedSAETrainer-0/frac_variance_explained","GatedSAETrainer-0/l0","GatedSAETrainer-0/l1_loss","GatedSAETrainer-0/l2_loss","GatedSAETrainer-0/l2_ratio","GatedSAETrainer-0/loss","GatedSAETrainer-0/loss_original","GatedSAETrainer-0/loss_reconstructed","GatedSAETrainer-0/loss_zero","GatedSAETrainer-0/mse_loss","GatedSAETrainer-0/sparsity_loss","GatedSAETrainer-1/aux_loss","GatedSAETrainer-1/cossim","GatedSAETrainer-1/frac_alive","GatedSAETrainer-1/frac_recovered","GatedSAETrainer-1/frac_variance_explained","GatedSAETrainer-1/l0","GatedSAETrainer-1/l1_loss","GatedSAETrainer-1/l2_loss","GatedSAETrainer-1/l2_ratio","GatedSAETrainer-1/loss","GatedSAETrainer-1/loss_original","GatedSAETrainer-1/loss_reconstructed","GatedSAETrainer-1/loss_zero","GatedSAETrainer-1/mse_loss","GatedSAETrainer-1/sparsity_loss"
2 | "rare-firefly-6","finished","-","","","2024-07-11T18:51:52.000Z","51591","","768","cuda:4","dictionary_learning.dictionary.GatedAutoEncoder","24576","5","8","openai-community/gpt2","0.0003","0","dictionary_learning.trainers.gdm.GatedSAETrainer","GatedSAETrainer","1000","768","cuda:4","dictionary_learning.dictionary.GatedAutoEncoder","24576","6","8","openai-community/gpt2","0.0003","0","dictionary_learning.trainers.gdm.GatedSAETrainer","GatedSAETrainer","1000","1163.14306640625","0.973246932029724","0.00004069010537932627","0.9953534007072448","0.9913633465766908","127.900146484375","528.8916015625","27.75778579711914","0.9604530334472656","3927.76611328125","3.3303747177124023","3.3955864906311035","17.364582061767578","791.554931640625","394.5193481445313","1230.1353759765625","0.9700021743774414","0.00004069010537932627","0.9947134852409364","0.984986126422882","104.4046630859375","507.0380859375","29.08409881591797","0.9812511205673218","4367.89501953125","3.26992130279541","3.3456764221191406","17.59981346130371","870.0953979492188","376.71624755859375"
3 | "dashing-lion-5","finished","-","","","2024-07-11T05:25:59.000Z","50846","","768","cuda:7","dictionary_learning.dictionary.GatedAutoEncoder","24576","20.56","8","openai-community/gpt2","0.0003","0","dictionary_learning.trainers.gdm.GatedSAETrainer","GatedSAETrainer","1000","768","cuda:7","dictionary_learning.dictionary.GatedAutoEncoder","24576","30","8","openai-community/gpt2","0.0003","0","dictionary_learning.trainers.gdm.GatedSAETrainer","GatedSAETrainer","1000","3559.26806640625","0.9084912538528442","0.00004069010537932627","0.9579382538795472","0.9719541668891908","12.566650390625","378.3607177734375","49.76805114746094","0.895888090133667","11920.619140625","3.3303747177124023","3.92067813873291","17.364582061767578","2559.174560546875","280.8070068359375","4629.88525390625","0.8717569708824158","0.00004069010537932627","0.9255402088165284","0.9396619200706482","8.206787109375","312.36065673828125","58.12748718261719","0.8806691765785217","15425.427734375","3.26992130279541","4.336921691894531","17.59981346130371","3492.701171875","243.12893676757812"
4 | "graceful-totem-4","finished","-","","","2024-07-11T05:25:45.000Z","51864","","768","cuda:6","dictionary_learning.dictionary.GatedAutoEncoder","24576","9.65","8","openai-community/gpt2","0.0003","0","dictionary_learning.trainers.gdm.GatedSAETrainer","GatedSAETrainer","1000","768","cuda:6","dictionary_learning.dictionary.GatedAutoEncoder","24576","14.09","8","openai-community/gpt2","0.0003","0","dictionary_learning.trainers.gdm.GatedSAETrainer","GatedSAETrainer","1000","1701.7183837890625","0.9564532041549684","0.00004069010537932627","0.9901224970817566","0.9862537980079652","48.8880615234375","440.683349609375","34.85649871826172","0.9686530232429504","6070.19482421875","3.3303747177124023","3.4689981937408447","17.364582061767578","1253.7666015625","322.1306457519531","2614.995361328125","0.934950828552246","0.00004069010537932627","0.9785899519920348","0.968268096446991","23.200927734375","387.74371337890625","42.21244812011719","0.9180249571800232","8781.263671875","3.26992130279541","3.5767250061035156","17.59981346130371","1837.4671630859375","305.5757751464844"
5 | "silver-puddle-3","finished","-","","","2024-07-11T05:25:35.000Z","51579","","768","cuda:5","dictionary_learning.dictionary.GatedAutoEncoder","24576","4.53","8","openai-community/gpt2","0.0003","0","dictionary_learning.trainers.gdm.GatedSAETrainer","GatedSAETrainer","1000","768","cuda:5","dictionary_learning.dictionary.GatedAutoEncoder","24576","6.62","8","openai-community/gpt2","0.0003","0","dictionary_learning.trainers.gdm.GatedSAETrainer","GatedSAETrainer","1000","1059.000244140625","0.975863754749298","0.00004069010537932627","0.9960606098175048","0.9922483563423156","159.738525390625","578.526123046875","26.311534881591797","0.9801684617996216","3686.526611328125","3.3303747177124023","3.3856611251831055","17.364582061767578","710.6533203125","420.4597473144531","1368.194091796875","0.9670422077178956","0.00004069010537932627","0.99338698387146","0.9835944175720216","82.66162109375","464.7086181640625","30.388513565063477","0.9570202827453612","4679.21875","3.26992130279541","3.364684820175171","17.59981346130371","950.4469604492188","356.64434814453125"
6 | "generous-fog-2","finished","-","","","2024-07-11T05:25:20.000Z","51422","","768","cuda:4","dictionary_learning.dictionary.GatedAutoEncoder","24576","2.13","8","openai-community/gpt2","0.0003","0","dictionary_learning.trainers.gdm.GatedSAETrainer","GatedSAETrainer","1000","768","cuda:4","dictionary_learning.dictionary.GatedAutoEncoder","24576","3.11","8","openai-community/gpt2","0.0003","0","dictionary_learning.trainers.gdm.GatedSAETrainer","GatedSAETrainer","1000","603.53125","0.9911146759986876","0.00004069010537932627","0.998424232006073","0.9973034858703612","529.562744140625","1140.15185546875","16.387954711914062","0.993000864982605","2631.47314453125","3.3303747177124023","3.352489471435547","17.364582061767578","278.48040771484375","825.9996337890625","819.4901123046875","0.9827598333358764","0.00004069010537932627","0.9972550868988036","0.9914193153381348","289.6943359375","723.2354736328125","22.326417922973633","0.9679235219955444","3019.125","3.26992130279541","3.3092551231384277","17.59981346130371","508.3369140625","546.3135986328125"
7 | "pleasant-oath-1","finished","-","","","2024-07-11T05:21:07.000Z","52028","","768","cuda:6","dictionary_learning.dictionary.GatedAutoEncoder","24576","1","8","openai-community/gpt2","0.0003","0","dictionary_learning.trainers.gdm.GatedSAETrainer","GatedSAETrainer","1000","768","cuda:6","dictionary_learning.dictionary.GatedAutoEncoder","24576","1.46","8","openai-community/gpt2","0.0003","0","dictionary_learning.trainers.gdm.GatedSAETrainer","GatedSAETrainer","1000","245.25555419921875","0.9967520833015442","0.00004069010537932627","0.9992234110832214","0.9996508359909058","645.68798828125","1738.0128173828125","9.955805778503418","0.995638370513916","1791.7958984375","3.3303747177124023","3.341273546218872","17.364582061767578","102.31087493896484","1449.95556640625","388.83349609375","0.9960089921951294","0.00004069010537932627","0.9988710284233092","0.9988107085227966","565.810791015625","1553.6287841796875","10.871599197387695","0.9927268028259276","2247.65771484375","3.26992130279541","3.2860991954803467","17.59981346130371","119.66344451904295","1189.346923828125"
--------------------------------------------------------------------------------
/results/primary/primary-relu-clean.csv:
--------------------------------------------------------------------------------
1 | ,l0,mse_loss,frac_recovered,delta_ce
2 | 5,15.46533203125,5551.58984375,0.900568962097168,1.4248356819152832
3 | 3,15.949462890625,3971.114501953125,0.9313362836837769,0.9839439392089844
4 | 9,16.97021484375,3336.4921875,0.9437114000320436,0.8066086769104004
5 | 11,21.419921875,2474.707763671875,0.9623249769210817,0.5398788452148438
6 | 13,31.735595703125,1932.8282470703125,0.974922776222229,0.35935378074645996
7 | 12,44.6102294921875,1713.539306640625,0.979756772518158,0.28409790992736816
8 | 15,53.3970947265625,1546.72998046875,0.9833840727806092,0.23810505867004395
9 | 14,75.8472900390625,1370.767333984375,0.9866704344749452,0.18707013130187988
10 | 17,101.04296875,1230.6522216796875,0.9891743063926696,0.15513181686401367
11 | 16,150.755615234375,1069.6497802734375,0.9915485978126526,0.11860823631286666
12 | 1,213.8397216796875,946.8670654296876,0.9934131503105164,0.09438896179199219
--------------------------------------------------------------------------------
/results/primary/primary-topk-clean.csv:
--------------------------------------------------------------------------------
1 | ,l0,mse_loss,frac_recovered,delta_ce
2 | 2,8.0,2219.68701171875,0.9644190669059752,0.49935078620910645
3 | 3,16.0,1562.1041259765625,0.9841692447662354,0.2268538475036621
4 | 4,32.0,1142.533447265625,0.9910413026809692,0.12572836875915527
5 | 5,48.0,991.705810546875,0.9935886263847352,0.0918736457824707
6 | 6,64.0,888.516845703125,0.9945665597915648,0.07625389099121094
7 | 7,96.0,761.394287109375,0.9957606792449952,0.060749053955078125
8 | 0,128.0,661.279541015625,0.9967961311340332,0.04496359825134277
9 | 1,192.0,522.2657470703125,0.9979140758514404,0.029891014099121094
10 |
--------------------------------------------------------------------------------
/results/primary/primary-topk.csv:
--------------------------------------------------------------------------------
1 | "Name","State","Notes","User","Tags","Created","Runtime","Sweep","AutoEncoderTopK-0.activation_dim","AutoEncoderTopK-0.auxk_alpha","AutoEncoderTopK-0.decay_start","AutoEncoderTopK-0.device","AutoEncoderTopK-0.dict_class","AutoEncoderTopK-0.dict_size","AutoEncoderTopK-0.k","AutoEncoderTopK-0.layer","AutoEncoderTopK-0.lm_name","AutoEncoderTopK-0.seed","AutoEncoderTopK-0.steps","AutoEncoderTopK-0.trainer","AutoEncoderTopK-0.wandb_name","AutoEncoderTopK-1.activation_dim","AutoEncoderTopK-1.auxk_alpha","AutoEncoderTopK-1.decay_start","AutoEncoderTopK-1.device","AutoEncoderTopK-1.dict_class","AutoEncoderTopK-1.dict_size","AutoEncoderTopK-1.k","AutoEncoderTopK-1.layer","AutoEncoderTopK-1.lm_name","AutoEncoderTopK-1.seed","AutoEncoderTopK-1.steps","AutoEncoderTopK-1.trainer","AutoEncoderTopK-1.wandb_name","AutoEncoderTopK-0/auxk_loss","AutoEncoderTopK-0/cossim","AutoEncoderTopK-0/dead_features","AutoEncoderTopK-0/effective_l0","AutoEncoderTopK-0/frac_alive","AutoEncoderTopK-0/frac_recovered","AutoEncoderTopK-0/frac_variance_explained","AutoEncoderTopK-0/l0","AutoEncoderTopK-0/l1_loss","AutoEncoderTopK-0/l2_loss","AutoEncoderTopK-0/l2_ratio","AutoEncoderTopK-0/loss","AutoEncoderTopK-0/loss_original","AutoEncoderTopK-0/loss_reconstructed","AutoEncoderTopK-0/loss_zero","AutoEncoderTopK-0/mse_loss","AutoEncoderTopK-1/auxk_loss","AutoEncoderTopK-1/cossim","AutoEncoderTopK-1/dead_features","AutoEncoderTopK-1/effective_l0","AutoEncoderTopK-1/frac_alive","AutoEncoderTopK-1/frac_recovered","AutoEncoderTopK-1/frac_variance_explained","AutoEncoderTopK-1/l0","AutoEncoderTopK-1/l1_loss","AutoEncoderTopK-1/l2_loss","AutoEncoderTopK-1/l2_ratio","AutoEncoderTopK-1/loss","AutoEncoderTopK-1/loss_original","AutoEncoderTopK-1/loss_reconstructed","AutoEncoderTopK-1/loss_zero","AutoEncoderTopK-1/mse_loss"
2 | "royal-water-4","finished","-","","","2024-07-14T00:29:11.000Z","45132","","768","0.03125","80000","cuda:7","dictionary_learning.trainers.top_k.AutoEncoderTopK","24576","128","8","openai-community/gpt2","0","100000","dictionary_learning.trainers.top_k.TrainerTopK","AutoEncoderTopK","768","0.03125","80000","cuda:7","dictionary_learning.trainers.top_k.AutoEncoderTopK","24576","192","8","openai-community/gpt2","0","100000","dictionary_learning.trainers.top_k.TrainerTopK","AutoEncoderTopK","2.7579562811297365e-7","0.9772376418113708","3","128","0.00004069010537932627","0.9967961311340332","0.9927451014518738","500.12890625","850.615234375","25.00290298461914","0.976984202861786","660.1929931640625","3.3303747177124023","3.375338315963745","17.364582061767578","661.279541015625","0.0000015463250520042493","0.9819841384887696","18","192","0.00004069010537932627","0.9979140758514404","0.990972638130188","654.5338134765625","987.9059448242188","22.222896575927734","0.9815428256988524","517.9833374023438","3.26992130279541","3.2998123168945312","17.59981346130371","522.2657470703125"
3 | "drawn-voice-3","finished","-","","","2024-07-14T00:28:27.000Z","50754","","768","0.03125","80000","cuda:4","dictionary_learning.trainers.top_k.AutoEncoderTopK","24576","8","8","openai-community/gpt2","0","100000","dictionary_learning.trainers.top_k.TrainerTopK","AutoEncoderTopK","768","0.03125","80000","cuda:4","dictionary_learning.trainers.top_k.AutoEncoderTopK","24576","16","8","openai-community/gpt2","0","100000","dictionary_learning.trainers.top_k.TrainerTopK","AutoEncoderTopK","0.00014197103155311197","0.9215346574783324","17839","8","0.00004069010537932627","0.9644190669059752","0.9756476283073424","1591.396484375","1681.16650390625","45.909427642822266","0.92127525806427","2218.64794921875","3.3303747177124023","3.829725503921509","17.364582061767578","2219.68701171875","0.00005811716982861981","0.9451953768730164","9701","16","0.00004069010537932627","0.9841692447662354","0.972999393939972","317.203857421875","534.530029296875","38.529930114746094","0.9450527429580688","1550.5992431640625","3.26992130279541","3.4967751502990723","17.59981346130371","1562.1041259765625"
4 | "still-river-1","finished","-","","","2024-07-14T00:28:26.000Z","45766","","768","0.03125","80000","cuda:5","dictionary_learning.trainers.top_k.AutoEncoderTopK","24576","32","8","openai-community/gpt2","0","100000","dictionary_learning.trainers.top_k.TrainerTopK","AutoEncoderTopK","768","0.03125","80000","cuda:5","dictionary_learning.trainers.top_k.AutoEncoderTopK","24576","48","8","openai-community/gpt2","0","100000","dictionary_learning.trainers.top_k.TrainerTopK","AutoEncoderTopK","0.00004375946446089074","0.9603254795074464","1118","32","0.00004069010537932627","0.9910413026809692","0.9874651432037354","201.1285400390625","508.3766479492187","32.90629577636719","0.9599528908729552","1139.227294921875","3.3303747177124023","3.4561030864715576","17.364582061767578","1142.533447265625","0.0000035865712106897263","0.9654753804206848","32","48","0.00004069010537932627","0.9935886263847352","0.9828585386276244","245.3511962890625","538.3258666992188","30.67569351196289","0.9650439023971558","987.57861328125","3.26992130279541","3.361794948577881","17.59981346130371","991.705810546875"
5 | "vivid-bee-2","finished","-","","","2024-07-14T00:28:26.000Z","44044","","768","0.03125","80000","cuda:6","dictionary_learning.trainers.top_k.AutoEncoderTopK","24576","64","8","openai-community/gpt2","0","100000","dictionary_learning.trainers.top_k.TrainerTopK","AutoEncoderTopK","768","0.03125","80000","cuda:6","dictionary_learning.trainers.top_k.AutoEncoderTopK","24576","96","8","openai-community/gpt2","0","100000","dictionary_learning.trainers.top_k.TrainerTopK","AutoEncoderTopK","8.166377369889233e-7","0.9692695140838624","8","64","0.00004069010537932627","0.9945665597915648","0.9902520775794984","302.2916259765625","629.1004638671875","29.00596618652344","0.968896210193634","887.1822509765625","3.3303747177124023","3.4066286087036133","17.364582061767578","888.516845703125","5.727280267819879e-7","0.9736075401306152","6","96","0.00004069010537932627","0.9957606792449952","0.9868393540382384","401.2947998046875","703.3333740234375","26.865930557250977","0.9731273651123048","756.0817260742188","3.26992130279541","3.3306703567504883","17.59981346130371","761.394287109375"
--------------------------------------------------------------------------------
/results/primary/switch_sae_l0_lr.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amudide/switch_sae/1b484570c74063b8331cb658a6bbf1964048a7de/results/primary/switch_sae_l0_lr.pdf
--------------------------------------------------------------------------------
/results/primary/switch_sae_l0_mse.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amudide/switch_sae/1b484570c74063b8331cb658a6bbf1964048a7de/results/primary/switch_sae_l0_mse.pdf
--------------------------------------------------------------------------------
/results/primary/switch_sae_pareto.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amudide/switch_sae/1b484570c74063b8331cb658a6bbf1964048a7de/results/primary/switch_sae_pareto.pdf
--------------------------------------------------------------------------------
/results/primary/switch_sae_pareto_flop.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amudide/switch_sae/1b484570c74063b8331cb658a6bbf1964048a7de/results/primary/switch_sae_pareto_flop.pdf
--------------------------------------------------------------------------------
/results/primary/switch_sae_pareto_width.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amudide/switch_sae/1b484570c74063b8331cb658a6bbf1964048a7de/results/primary/switch_sae_pareto_width.pdf
--------------------------------------------------------------------------------
/save_activations.py:
--------------------------------------------------------------------------------
1 | #%%
2 |
3 | from nnsight import LanguageModel
4 | from dictionary_learning.utils import hf_dataset_to_generator
5 | from config import lm, activation_dim, layer, hf, n_ctxs
6 | import torch as t
7 | import einops
8 | from tqdm import tqdm
9 | import os
10 |
# Inference only: nothing in this script needs autograd graphs.
t.set_grad_enabled(False)

# %%


device = 'cuda:1'
model = LanguageModel(lm, dispatch=True, device_map=device)
submodule = model.transformer.h[layer]  # transformer block whose output is recorded
data = hf_dataset_to_generator(hf)

# %%
# Size of the dump: num_batches batches of batch_size contexts of ctx_len tokens.
batch_size = 256
num_batches = 128
ctx_len = 128

n_contexts = num_batches * batch_size
total_tokens = n_contexts * ctx_len
# The memory estimate assumes 4 bytes (float32) per stored activation value.
total_memory = 4 * activation_dim * total_tokens
print(f"Total contexts: {n_contexts / 1e3:.2f}K")
print(f"Total tokens: {total_tokens / 1e6:.2f}M")
print(f"Total memory: {total_memory / 1e9:.2f}GB")
31 |
32 | # %%
33 |
34 | # These functions copied from buffer.py
35 |
def text_batch():
    """Return the next ``batch_size`` raw documents drawn from the ``data`` generator.

    Raises:
        StopIteration: if ``data`` runs out before the batch is full.
    """
    batch = []
    for _ in range(batch_size):
        batch.append(next(data))
    return batch
40 |
def tokenized_batch():
    """Tokenize the next text batch into a fixed-width (batch_size, ctx_len) batch.

    Documents longer than ``ctx_len`` tokens are truncated and shorter ones are
    padded up to exactly ``ctx_len``. Every batch therefore has the same width,
    so the per-batch token and activation tensors can be concatenated along
    dim 0.

    Returns:
        The tokenizer output with ``input_ids`` and ``attention_mask`` as
        PyTorch tensors.
    """
    texts = text_batch()
    return model.tokenizer(
        texts,
        return_tensors='pt',
        max_length=ctx_len,
        # 'max_length' rather than True: padding=True only pads to the longest
        # text in this batch, so the width could vary from batch to batch.
        padding='max_length',
        truncation=True,
    )
50 |
def get_activations(input):
    """Run ``model`` on a tokenized batch and return ``submodule``'s output at non-padding positions.

    Args:
        input: a tokenized batch (as produced by ``tokenized_batch``); its
            ``attention_mask`` selects which positions are kept.

    Returns:
        The kept positions' hidden states with the batch and position
        dimensions flattened into one. Presumably shape (n_kept_tokens, d_model);
        TODO confirm against the submodule's output shape.
    """
    with t.no_grad():
        with model.trace(input):
            hidden_states = submodule.output.save()
        # Read the saved value only after the trace has executed.
        hidden_states = hidden_states.value
        # If the submodule returned a tuple, keep only its first element.
        if isinstance(hidden_states, tuple):
            hidden_states = hidden_states[0]
        # Boolean-mask indexing drops padded positions and flattens batch/position.
        hidden_states = hidden_states[input['attention_mask'] != 0]
        return hidden_states
60 |
61 |
62 |
63 | # %%
64 |
all_activations = []
all_tokens = []

for _ in tqdm(range(num_batches)):
    batch = tokenized_batch()
    tokens = batch['input_ids'].cpu()
    mask = batch['attention_mask'].cpu() != 0
    all_tokens.append(tokens)
    # get_activations returns only the non-padding positions, flattened to
    # (n_kept, d). A plain "(b c) d -> b c d" reshape is only valid when no
    # position was padded, so scatter the rows back into a dense
    # (batch, ctx, d) tensor aligned with `tokens`; padded positions stay zero.
    kept = get_activations(batch).cpu()
    activations = t.zeros(*mask.shape, kept.shape[-1], dtype=kept.dtype)
    activations[mask] = kept
    all_activations.append(activations)

# %%

concatenated_activations = t.cat(all_activations)
concatenated_tokens = t.cat(all_tokens)
print(concatenated_activations.shape, concatenated_tokens.shape)

# %%

# save activations
os.makedirs('data', exist_ok=True)
t.save(concatenated_activations, f'data/gpt2_activations_layer{layer}.pt')
t.save(concatenated_tokens, 'data/gpt2_tokens.pt')
--------------------------------------------------------------------------------
/scaling_laws/attempt0/train-dense-topk.py:
--------------------------------------------------------------------------------
1 | """
2 | Trains dense TopK SAEs of varying scale.
3 | """
4 |
5 | import os
6 | import sys
7 | sys.path.append( # the switch_sae directory
8 | os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
9 | )
10 | os.environ['HF_HOME'] = '/om/user/ericjm/.cache/huggingface'
11 |
12 |
13 | from nnsight import LanguageModel
14 | import torch as t
15 | from dictionary_learning import ActivationBuffer
16 | from dictionary_learning.training import trainSAE
17 | from dictionary_learning.utils import hf_dataset_to_generator, cfg_filename
18 | from dictionary_learning.trainers.top_k import AutoEncoderTopK, TrainerTopK
19 | from dictionary_learning.evaluation import evaluate
20 | import wandb
21 |
22 |
lm = 'openai-community/gpt2'
activation_dim = 768
layer = 8
hf = 'Skylion007/openwebtext'
steps = 100_000  # training steps, shared by every SAE in the sweep
n_ctxs = int(1e4)

# One dense TopK SAE per dictionary size, all with the same sparsity k.
dict_sizes = [2048, 4096, 8192, 16384, 32768, 65536, 131072]
k = 32

save_dir = "topk_dense"

device = 'cuda:0'
model = LanguageModel(lm, dispatch=True, device_map=device)
submodule = model.transformer.h[layer]
data = hf_dataset_to_generator(hf)
buffer = ActivationBuffer(data, model, submodule, d_submodule=activation_dim, n_ctxs=n_ctxs, device=device)


# NOTE: keep this key order stable. The checkpoint directory names used by the
# eval scripts look like they are derived from the config's items, in order.
base_trainer_config = {
    'trainer' : TrainerTopK,
    'dict_class' : AutoEncoderTopK,
    'activation_dim' : activation_dim,
    'k' : k,
    'auxk_alpha' : 1/32,
    'decay_start' : int(steps * 0.8),
    'steps' : steps,
    'seed' : 0,
    'device' : device,
    'layer' : layer,
    'lm_name' : lm,
    'wandb_name' : 'AutoEncoderTopK'
}

trainer_configs = [base_trainer_config | {'dict_size': ds} for ds in dict_sizes]

wandb.init(
    entity="ericjmichaud_",
    project="switch_saes_scaling_laws_attempt_0",
    config={f"{cfg['wandb_name']}-{i}": cfg for i, cfg in enumerate(trainer_configs)},
)

trainSAE(buffer, trainer_configs=trainer_configs, save_dir=save_dir, log_steps=1, steps=steps)

print("Training finished. Evaluating SAE...", flush=True)
for i, trainer_config in enumerate(trainer_configs):
    ae = AutoEncoderTopK.from_pretrained(
        os.path.join(save_dir, cfg_filename(trainer_config), 'ae.pt'),
        k=trainer_config['k'],
        device=device,
    )
    metrics = evaluate(ae, buffer, device=device)
    # Metric keys are named `name`, not `k`, to avoid confusion with the TopK k above.
    prefix = f'{trainer_config["wandb_name"]}-{i}'
    # Logged one step past training so the eval metrics follow the training curves.
    wandb.log({f'{prefix}/{name}': value for name, value in metrics.items()}, step=steps + 1)
wandb.finish()
75 |
--------------------------------------------------------------------------------
/scaling_laws/attempt0/train-dense-topk.sh:
--------------------------------------------------------------------------------
#!/bin/bash
# SLURM job: run the dense TopK SAE sweep (train-dense-topk.py) on one A100.
#SBATCH -J dtk0
#SBATCH -p tegmark
#SBATCH -n 1
#SBATCH --mem=16GB
#SBATCH --gres=gpu:a100:1
#SBATCH -t 2-00:00:00
#SBATCH -o /om2/user/ericjm/switch_sae/scaling_laws/attempt0/slurm-%j.out

python /om2/user/ericjm/switch_sae/scaling_laws/attempt0/train-dense-topk.py
11 |
--------------------------------------------------------------------------------
/scaling_laws/attempt0/train-switch-topk.py:
--------------------------------------------------------------------------------
1 | """
2 | Trains Switch SAEs of varying scale.
3 | """
4 |
5 | import os
6 | import sys
7 | sys.path.append( # the switch_sae directory
8 | os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
9 | )
10 | os.environ['HF_HOME'] = '/om/user/ericjm/.cache/huggingface'
11 |
12 |
13 | from nnsight import LanguageModel
14 | import torch as t
15 | from dictionary_learning import ActivationBuffer
16 | from dictionary_learning.training import trainSAE
17 | from dictionary_learning.utils import hf_dataset_to_generator, cfg_filename, str2bool
18 | from dictionary_learning.trainers.switch import SwitchAutoEncoder, SwitchTrainer
19 | from dictionary_learning.evaluation import evaluate
20 | import wandb
21 | import argparse
22 | # import itertools
23 | # from config import lm, activation_dim, layer, hf, steps, n_ctxs
24 |
parser = argparse.ArgumentParser()
# Required (the sbatch array script always passes it). No default is given,
# because argparse never uses a default for a required option.
parser.add_argument("--num_experts", type=int, required=True)
parser.add_argument("--lb_alpha", type=float, default=3.0)
parser.add_argument("--heaviside", type=str2bool, default=False)
args = parser.parse_args()

lm = 'openai-community/gpt2'
activation_dim = 768
layer = 8
hf = 'Skylion007/openwebtext'
steps = 100_000  # training steps, shared by every SAE in the sweep
n_ctxs = int(1e4)

# One Switch SAE per dictionary size, all with the same sparsity k.
dict_sizes = [2048, 4096, 8192, 16384, 32768, 65536, 131072]
k = 32

save_dir = f"topk_switch{args.num_experts}"

device = 'cuda:0'
model = LanguageModel(lm, dispatch=True, device_map=device)
submodule = model.transformer.h[layer]
data = hf_dataset_to_generator(hf)
buffer = ActivationBuffer(data, model, submodule, d_submodule=activation_dim, n_ctxs=n_ctxs, device=device)

# NOTE: keep this key order stable. The checkpoint directory names used by the
# eval scripts look like they are derived from the config's items, in order.
base_trainer_config = {
    'trainer' : SwitchTrainer,
    'dict_class' : SwitchAutoEncoder,
    'activation_dim' : activation_dim,
    'k': k,
    'experts' : args.num_experts,
    'lb_alpha' : args.lb_alpha,
    'heaviside' : args.heaviside,
    'auxk_alpha' : 1/32,
    'decay_start' : int(steps * 0.8),
    'steps' : steps,
    'seed' : 0,
    'device' : device,
    'layer' : layer,
    'lm_name' : lm,
    'wandb_name' : 'SwitchAutoEncoder'
}

trainer_configs = [base_trainer_config | {'dict_size': ds} for ds in dict_sizes]

wandb.init(
    entity="ericjmichaud_",
    project="switch_saes_scaling_laws_attempt_0",
    config={f"{cfg['wandb_name']}-{i}": cfg for i, cfg in enumerate(trainer_configs)},
)

trainSAE(buffer, trainer_configs=trainer_configs, save_dir=save_dir, log_steps=1, steps=steps)

print("Training finished. Evaluating SAE...", flush=True)
for i, trainer_config in enumerate(trainer_configs):
    ae = SwitchAutoEncoder.from_pretrained(
        os.path.join(save_dir, cfg_filename(trainer_config), 'ae.pt'),
        k=trainer_config['k'],
        experts=trainer_config['experts'],
        heaviside=trainer_config['heaviside'],
        device=device,
    )
    metrics = evaluate(ae, buffer, device=device)
    # Metric keys are named `name`, not `k`, to avoid confusion with the TopK k above.
    prefix = f'{trainer_config["wandb_name"]}-{i}'
    # Logged one step past training so the eval metrics follow the training curves.
    wandb.log({f'{prefix}/{name}': value for name, value in metrics.items()}, step=steps + 1)
wandb.finish()
89 |
--------------------------------------------------------------------------------
/scaling_laws/attempt0/train-switch-topk.sh:
--------------------------------------------------------------------------------
#!/bin/bash
# SLURM array job: one Switch TopK SAE sweep per array task. The array task ID
# is passed as --num_experts, so --array=8,64 trains 8- and 64-expert sweeps.
# The job name differs from the dense sweep's (dtk0) so the two can be told
# apart in the queue.
#SBATCH --job-name=stk0
#SBATCH --partition=tegmark
#SBATCH --ntasks=1
#SBATCH --mem=16GB
#SBATCH --gres=gpu:a100:1
#SBATCH --time=2-00:00:00
#SBATCH --output=/om2/user/ericjm/switch_sae/scaling_laws/attempt0/slurm-%A_%a.out
#SBATCH --array=8,64

python /om2/user/ericjm/switch_sae/scaling_laws/attempt0/train-switch-topk.py --num_experts "$SLURM_ARRAY_TASK_ID"
12 |
--------------------------------------------------------------------------------
/table/create_table_script.py:
--------------------------------------------------------------------------------
# %%

import itertools
import os

# Jobs are assigned round-robin to cuda:0 .. cuda:{NUM_GPUS - 1}, and the
# generated script waits after every NUM_GPUS background jobs. Each GPT-2 batch
# therefore uses every GPU at most once.
NUM_GPUS = 8

# Generate all combinations of parameters
layers = [2, 4, 6, 8, 10]
sae_types = ['switch', 'topk']
types = ['resid', 'attn', 'mlp']

# First set of commands for GPT-2: one job per (layer, sae_type, type), in
# itertools.product order. The device is the job's position mod NUM_GPUS.
gpt2_commands = []
for layer, sae_type, type_ in itertools.product(layers, sae_types, types):
    gpu = len(gpt2_commands) % NUM_GPUS
    cmd = f"python3 table/train_switch_table.py --device cuda:{gpu} --layer {layer} --lm openai-community/gpt2 --ks 64 --activation_dim 768 --dict_ratio 32 --num_experts 8 --type {type_} --sae_type {sae_type} --steps 20000 &"
    gpt2_commands.append(cmd)

# Second set of commands for Gemma-2B (layer 12, --type resid, one job per SAE type)
gemma_commands = []
for i, sae_type in enumerate(sae_types):
    cmd = f"python3 table/train_switch_table.py --device cuda:{i % NUM_GPUS} --layer 12 --lm google/gemma-2b --ks 64 --activation_dim 2048 --dict_ratio 32 --num_experts 8 --type resid --sae_type {sae_type} --steps 2000000 &"
    gemma_commands.append(cmd)

# Write commands to a bash script
with open('run_parallel.sh', 'w') as f:
    f.write('#!/bin/bash\n\n')

    # GPT-2 commands, in batches of NUM_GPUS
    f.write('# GPT-2 training commands\n')
    for n_written, cmd in enumerate(gpt2_commands, start=1):
        f.write(f'{cmd}\n')
        if n_written % NUM_GPUS == 0:
            f.write('\nwait\n\n')  # Wait after every NUM_GPUS commands

    # A final partial batch also needs a wait
    if len(gpt2_commands) % NUM_GPUS:
        f.write('\nwait\n\n')

    # Gemma commands
    f.write('# Gemma-2B training commands\n')
    for cmd in gemma_commands:
        f.write(f'{cmd}\n')

    f.write('\nwait\n')  # Wait for all commands to finish

# Make the script executable
os.chmod('run_parallel.sh', 0o755)

print(f"Created run_parallel.sh with {len(gpt2_commands)} GPT-2 commands and {len(gemma_commands)} Gemma commands")
50 |
51 |
--------------------------------------------------------------------------------
/table/eval_table.py:
--------------------------------------------------------------------------------
1 | # %%
2 | import os
3 | import sys
4 | import pandas as pd
5 |
6 | # Add the parent directory to the sys path
7 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
8 |
9 |
10 | import itertools
11 | import os
12 | from dictionary_learning.trainers.top_k import AutoEncoderTopK, TrainerTopK
13 | from nnsight import LanguageModel
14 | import torch as t
15 | from dictionary_learning import ActivationBuffer
16 | from dictionary_learning.training import trainSAE
17 | from dictionary_learning.utils import hf_dataset_to_generator, cfg_filename, str2bool
18 | from dictionary_learning.trainers.switch import SwitchAutoEncoder, SwitchTrainer
19 | from modified_eval import evaluate
20 | import wandb
21 | import argparse
22 | import itertools
23 | from config import hf
24 |
# Directory containing the trained SAE checkpoint folders. Set the
# SWITCH_SAE_DICT_DIR environment variable to evaluate checkpoints stored
# elsewhere without editing this file.
base_path = os.environ.get("SWITCH_SAE_DICT_DIR", "/home/jengels/switch_sae/other_dictionaries")

# SAE variants to compare; get_model_metrics collects metrics in this order.
sae_types = ['switch', 'topk']
28 |
29 | # %%
30 |
def _get_submodule(model, model_name, layer, layer_type):
    """Return the submodule of ``model`` at ``layer`` selected by ``layer_type``.

    Gemma-2B's decoder blocks are read from ``model.model.layers``; every other
    model's from ``model.transformer.h``. ``layer_type`` "resid" selects the
    whole block, while "mlp" / "attn" select its ``mlp`` / ``attn`` child.

    Raises:
        ValueError: for any other ``layer_type``.
    """
    if model_name == "google/gemma-2b":
        block = model.model.layers[layer]
    else:
        block = model.transformer.h[layer]

    if layer_type == "resid":
        return block
    if layer_type == "mlp":
        return block.mlp
    if layer_type == "attn":
        return block.attn
    raise ValueError(f"Unknown layer_type: {layer_type!r}")


def _load_sae(sae_type, model_path, device):
    """Load an SAE checkpoint: k=64 for both types; switch SAEs use 8 experts and heaviside=False."""
    if sae_type == "switch":
        return SwitchAutoEncoder.from_pretrained(
            model_path,
            k=64,
            experts=8,
            heaviside=False,
            device=device,
        )
    return AutoEncoderTopK.from_pretrained(
        model_path,
        k=64,
        device=device,
    )


def get_model_metrics(model_name, layers_to_eval, layer_types, sae_types, info_to_filename, d_submodule, device="cuda:0"):
    """Evaluate every trained SAE of one language model.

    Args:
        model_name: Hugging Face model id to load with nnsight.
        layers_to_eval: layer indices to evaluate.
        layer_types: component types per layer ("resid", "mlp", "attn").
        sae_types: SAE variants ("switch" / anything else is loaded as TopK).
        info_to_filename: maps (layer, sae_type, layer_type) -> checkpoint dir.
        d_submodule: activation dimension of the evaluated submodule.
        device: device for the model and the SAEs.

    Returns:
        dict mapping (layer, layer_type) -> list of metric dicts, one per
        entry of ``sae_types`` and in that order. A (layer, layer_type) pair
        for which any SAE type has no checkpoint in ``info_to_filename`` is
        skipped with a message, so every list lines up with ``sae_types``.
    """
    is_gemma = model_name == "google/gemma-2b"
    model = LanguageModel(model_name, dispatch=True, device_map=device)
    # Gemma-2B uses single-sequence batches and more of them. The same value is
    # used for the buffer refresh and the loss-recovered batch size.
    batch_size = 1 if is_gemma else 128
    num_batches = 32 if is_gemma else 2
    info_to_metrics = {}

    for layer in layers_to_eval:
        for layer_type in layer_types:
            missing = [s for s in sae_types if (layer, s, layer_type) not in info_to_filename]
            if missing:
                print(f"Skipping layer {layer} {layer_type}: no checkpoint for {missing}")
                continue

            submodule = _get_submodule(model, model_name, layer, layer_type)
            values = []
            for sae_type in sae_types:
                # A fresh data generator per SAE, presumably so each SAE type is
                # evaluated on the same text (assumes a fixed document order).
                data = hf_dataset_to_generator(hf, split="train")
                buffer = ActivationBuffer(
                    data,
                    model,
                    submodule,
                    d_submodule=d_submodule,
                    n_ctxs=int(1e2),
                    device="cpu",
                    out_batch_size=8192,
                    refresh_batch_size=batch_size,
                )

                filename = info_to_filename[(layer, sae_type, layer_type)]
                if is_gemma:
                    # NOTE: reads the module-level `gemma_checkpoint_id`, which
                    # must be defined before this function is called for Gemma.
                    model_path = f"{filename}/checkpoints/ae_{gemma_checkpoint_id}.pt"
                else:
                    model_path = f"{filename}/ae.pt"

                ae = _load_sae(sae_type, model_path, device)
                metrics = evaluate(
                    ae,
                    buffer,
                    device=device,
                    batch_size=batch_size,
                    num_batches=num_batches,
                )
                values.append(metrics)
                print(metrics)

            info_to_metrics[(layer, layer_type)] = values

    del model
    return info_to_metrics
105 |
106 | # %%
107 |
108 |
# Generate all combinations of parameters
layers = [2, 4, 8, 10]
types = ['resid', 'attn', 'mlp']

# Map (layer, sae_type, layer_type) -> GPT-2 checkpoint directory.
# The directory names do not encode layer_type. Runs are told apart only by the
# training device, which the training jobs were assigned round-robin
# (job index mod 8) in itertools.product(layers, sae_types, types) order. The
# device number must therefore come from the combination's position in that
# product, not from how many checkpoints have been found so far. Otherwise one
# missing checkpoint would shift every later device number and silently pair
# SAEs with the wrong (layer, layer_type).
gpt_info_to_filename = {}
for job_index, (layer, sae_type, layer_type) in enumerate(itertools.product(layers, sae_types, types)):

    device_num = job_index % 8
    if sae_type == 'switch':
        filename = f"dict_class:.SwitchAutoEncoder'>_activation_dim:768_dict_size:24576_auxk_alpha:0.03125_decay_start:16000_steps:20000_seed:0_device:cuda:{device_num}_layer:{layer}_lm_name:openai-communitygpt2_wandb_name:SwitchAutoEncoder_k:64_experts:8_heaviside:False_lb_alpha:3"
    else:
        filename = f"dict_class:_k.AutoEncoderTopK'>_activation_dim:768_dict_size:24576_auxk_alpha:0.03125_decay_start:16000_steps:20000_seed:0_device:cuda:{device_num}_layer:{layer}_lm_name:openai-communitygpt2_wandb_name:AutoEncoderTopK_k:64"

    filename = os.path.join(base_path, filename)

    if not os.path.exists(filename):
        print(layer, sae_type, layer_type)
        continue

    gpt_info_to_filename[(layer, sae_type, layer_type)] = filename


# Get GPT-2 metrics
gpt_info_to_metrics = get_model_metrics(
    model_name="openai-community/gpt2",
    layers_to_eval=layers,
    layer_types=types,
    sae_types=sae_types,
    info_to_filename=gpt_info_to_filename,
    d_submodule=768,
    device="cuda:1"
)
144 |
145 |
146 | # %%
# Get Gemma metrics
# Initialized first so the name exists even if the evaluation below raises.
gemma_info_to_metrics = {}
# Gemma SAEs are loaded from checkpoints/ae_{gemma_checkpoint_id}.pt.
# get_model_metrics reads this module-level name, so it must be set before the
# call below.
gemma_checkpoint_id = 2000

# First create filename mapping for Gemma: layer 12, "resid" only. The i-th
# SAE type's run was trained on cuda:{i}.
layer = 12
gemma_info_to_filename = {}
for i, sae_type in enumerate(sae_types):
    if sae_type == "switch":
        filename = f"{base_path}/dict_class:.SwitchAutoEncoder'>_activation_dim:2048_dict_size:65536_auxk_alpha:0.03125_decay_start:1600000_steps:2000000_seed:0_device:cuda:{i}_layer:12_lm_name:googlegemma-2b_wandb_name:SwitchAutoEncoder_k:64_experts:8_heaviside:False_lb_alpha:3"
    else:
        filename = f"{base_path}/dict_class:_k.AutoEncoderTopK'>_activation_dim:2048_dict_size:65536_auxk_alpha:0.03125_decay_start:1600000_steps:2000000_seed:0_device:cuda:{i}_layer:12_lm_name:googlegemma-2b_wandb_name:AutoEncoderTopK_k:64"

    # Same existence check as the GPT-2 mapping.
    if not os.path.exists(filename):
        print(layer, sae_type, "resid")
        continue

    gemma_info_to_filename[(layer, sae_type, "resid")] = filename

gemma_info_to_metrics = get_model_metrics(
    model_name="google/gemma-2b",
    layers_to_eval=[12],
    layer_types=["resid"],
    sae_types=sae_types,
    info_to_filename=gemma_info_to_filename,
    d_submodule=2048,
    device="cuda:1"
)
172 |
173 |
174 |
175 |
176 |
177 | # %%
178 |
179 |
180 | # %%
rows = []
# Each metrics list is ordered like `sae_types`; look up positions rather than
# assuming switch is first and topk second.
switch_idx = sae_types.index('switch')
topk_idx = sae_types.index('topk')

# Add GPT-2 results, then Gemma results
for model_label, info_to_metrics in (('GPT-2', gpt_info_to_metrics), ('Gemma', gemma_info_to_metrics)):
    for (layer, layer_type), values in info_to_metrics.items():
        topk_metrics = values[topk_idx]
        switch_metrics = values[switch_idx]
        rows.append({
            'Model': model_label,
            'Layer': layer,
            'Type': layer_type,
            'TopK FVE': f"{topk_metrics['frac_variance_explained']:.3f}",
            'Switch FVE': f"{switch_metrics['frac_variance_explained']:.3f}",
            'TopK FR': f"{topk_metrics['frac_recovered']:.3f}",
            'Switch FR': f"{switch_metrics['frac_recovered']:.3f}",
        })

df = pd.DataFrame(rows)
print("\nResults Table:")
print(df.to_markdown(index=False))
211 |
212 | # %%
213 |
df.to_csv("results.csv", index=False)

# %%

# Convert to LaTeX table
latex_columns = ['Model', 'Layer', 'Type', 'TopK FVE', 'Switch FVE', 'TopK FR', 'Switch FR']
# One left-aligned column for the model name and one centered column for each
# of the rest. The spec must declare exactly as many columns as are emitted.
column_spec = 'l' + 'c' * (len(latex_columns) - 1)

latex_lines = [
    r"\begin{table}[h]",
    r"\centering",
    rf"\begin{{tabular}}{{{column_spec}}}",
    r"\toprule",
    " & ".join(latex_columns) + r" \\",
    r"\midrule",
]
for _, row in df.iterrows():
    latex_lines.append(" & ".join(str(row[col]) for col in latex_columns) + r" \\")
latex_lines += [
    r"\bottomrule",
    r"\end{tabular}",
    r"\caption{Comparison of TopK and Switch autoencoders across different models, layers and component types. FVE = Fraction of Variance Explained, FR = Fraction Recovered.}",
    r"\label{tab:model_comparison}",
    r"\end{table}",
]
latex_table = "\n".join(latex_lines)

print("\nLaTeX Table:")
print(latex_table)
238 |
--------------------------------------------------------------------------------
/table/modified_eval.py:
--------------------------------------------------------------------------------
1 | """
2 | Utilities for evaluating dictionaries on a model and dataset.
3 | """
4 |
5 | import os
6 | import sys
7 |
8 | # Add the parent directory to the sys path
9 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
10 |
11 |
12 | import torch as t
13 | from tqdm import tqdm
14 | from dictionary_learning.buffer import ActivationBuffer, NNsightActivationBuffer
15 | from nnsight import LanguageModel
16 | from dictionary_learning.config import DEBUG
17 |
18 |
def loss_recovered(
    text, # a batch of text, or a tensor of token ids
    model: LanguageModel, # an nnsight LanguageModel
    submodule, # the submodule of `model` whose input/output is replaced
    dictionary, # the dictionary (SAE) that reconstructs the captured activation
    max_len=None, # max context length for loss recovered (truncates if set)
    normalize_batch=False, # normalize batch before passing through dictionary
    io="out", # can be 'in', 'out', or 'in_and_out'
    tracer_args = {'use_cache': False, 'output_attentions': False}, # minimize cache during model trace. NOTE: shared default dict; only unpacked here, never modified.
):
    """
    How much of the model's loss is recovered by replacing the component output
    with the reconstruction by the autoencoder?

    Runs ``text`` through ``model`` four times:
      1. unmodified;
      2. to capture the activation ``x`` at ``submodule`` (its input for
         'in' / 'in_and_out', its output for 'out');
      3. with that site replaced by ``dictionary``'s reconstruction ``x_hat``;
      4. with the site replaced by zeros.
    For 'in_and_out', the reconstruction of the *input* is written to the
    submodule's *output*, and step 4 zeroes the output.
    ``tracer_args`` applies to runs 2-4 only.

    Returns:
        (loss_original, loss_reconstructed, loss_zero): next-token
        cross-entropy losses (scalar tensors) for runs 1, 3 and 4.

    Raises:
        ValueError: if ``io`` is not one of the supported values.
    """

    if max_len is None:
        invoker_args = {}
    else:
        # Truncate every sequence to max_len tokens at tokenization time.
        invoker_args = {"truncation": True, "max_length": max_len }

    # unmodified logits
    with model.trace(text, invoker_args=invoker_args):
        logits_original = model.output.save()
    logits_original = logits_original.value

    # logits when replacing component activations with reconstruction by autoencoder
    # (this pass only captures `x`; the replacement happens in the next trace)
    with model.trace(text, **tracer_args, invoker_args=invoker_args):
        if io == 'in':
            x = submodule.input[0]
            # `.shape` is a plain tuple (not torch.Size) when the proxied value
            # is presumably itself a tuple; then take its first element.
            if type(submodule.input.shape) == tuple: x = x[0]
            if normalize_batch:
                # Rescale so the mean activation norm becomes sqrt(activation_dim).
                scale = (dictionary.activation_dim ** 0.5) / x.norm(dim=-1).mean()
                x = x * scale
        elif io == 'out':
            x = submodule.output
            if type(submodule.output.shape) == tuple: x = x[0]
            if normalize_batch:
                scale = (dictionary.activation_dim ** 0.5) / x.norm(dim=-1).mean()
                x = x * scale
        elif io == 'in_and_out':
            x = submodule.input[0]
            if type(submodule.input.shape) == tuple: x = x[0]
            # NOTE(review): debug print; fires on every call with io='in_and_out'.
            print(f'x.shape: {x.shape}')
            if normalize_batch:
                scale = (dictionary.activation_dim ** 0.5) / x.norm(dim=-1).mean()
                x = x * scale
        else:
            raise ValueError(f"Invalid value for io: {io}")
        x = x.save()

    # pull this out so dictionary can be written without FakeTensor (top_k needs this)
    # The dictionary sees a 2-D (-1, last_dim) view; the result is reshaped back to x's shape.
    x_hat = dictionary(x.view(-1, x.shape[-1])).view(x.shape)

    # intervene with `x_hat`
    with model.trace(text, **tracer_args, invoker_args=invoker_args):
        if io == 'in':
            x = submodule.input[0]
            if normalize_batch:
                # Undo the normalization using this pass's activation norms.
                scale = (dictionary.activation_dim ** 0.5) / x.norm(dim=-1).mean()
                x_hat = x_hat / scale
            if type(submodule.input.shape) == tuple:
                submodule.input[0][:] = x_hat
            else:
                submodule.input = x_hat
        elif io == 'out':
            x = submodule.output
            if normalize_batch:
                scale = (dictionary.activation_dim ** 0.5) / x.norm(dim=-1).mean()
                x_hat = x_hat / scale
            if type(submodule.output.shape) == tuple:
                # Replaces the whole output tuple with a 1-tuple; any other
                # elements of the original output are dropped.
                submodule.output = (x_hat,)
            else:
                submodule.output = x_hat
        elif io == 'in_and_out':
            x = submodule.input[0]
            if normalize_batch:
                scale = (dictionary.activation_dim ** 0.5) / x.norm(dim=-1).mean()
                x_hat = x_hat / scale
            # Reconstruction of the input is substituted for the output.
            submodule.output = x_hat
        else:
            raise ValueError(f"Invalid value for io: {io}")

        logits_reconstructed = model.output.save()
    logits_reconstructed = logits_reconstructed.value

    # logits when replacing component activations with zeros
    with model.trace(text, **tracer_args, invoker_args=invoker_args):
        if io == 'in':
            x = submodule.input[0]
            if type(submodule.input.shape) == tuple:
                # NOTE(review): x is already submodule.input[0], so x[0] indexes one
                # level deeper; confirm this zero-fills the intended tensor.
                submodule.input[0][:] = t.zeros_like(x[0])
            else:
                submodule.input = t.zeros_like(x)
        elif io in ['out', 'in_and_out']:
            x = submodule.output
            if type(submodule.output.shape) == tuple:
                submodule.output[0][:] = t.zeros_like(x[0])
            else:
                submodule.output = t.zeros_like(x)
        else:
            raise ValueError(f"Invalid value for io: {io}")

        # Saved so the token ids can be recovered below when `text` is not a tensor.
        input = model.input.save()
        logits_zero = model.output.save()
    logits_zero = logits_zero.value

    # get everything into the right format
    # Outputs may be objects with a `.logits` attribute or plain tensors.
    # NOTE(review): the bare except also hides unrelated errors.
    try:
        logits_original = logits_original.logits
        logits_reconstructed = logits_reconstructed.logits
        logits_zero = logits_zero.logits
    except:
        pass

    if isinstance(text, t.Tensor):
        tokens = text
    else:
        # input[1] presumably holds the tokenized kwargs; the key name varies
        # ('input_ids' vs 'input'). TODO confirm against the nnsight version.
        try:
            tokens = input[1]['input_ids']
        except:
            tokens = input[1]['input']

    # compute losses
    losses = []
    # Pad-token targets are ignored. If the tokenizer reuses a real token as its
    # pad token, genuine occurrences of that token are ignored too.
    if hasattr(model, 'tokenizer') and model.tokenizer is not None:
        loss_kwargs = {'ignore_index': model.tokenizer.pad_token_id}
    else:
        loss_kwargs = {}
    for logits in [logits_original, logits_reconstructed, logits_zero]:
        # Shift by one: logits at position i are scored against token i + 1.
        loss = t.nn.CrossEntropyLoss(**loss_kwargs)(
            logits[:, :-1, :].reshape(-1, logits.shape[-1]), tokens[:, 1:].reshape(-1)
        )
        losses.append(loss)

    return tuple(losses)
154 |
155 |
def _reconstruction_metrics(x, x_hat, f, dict_size):
    """Compute one batch's reconstruction and sparsity metrics as Python floats.

    Args:
        x: input activations; the last dimension is the activation dimension.
        x_hat: the dictionary's reconstruction of ``x`` (same shape).
        f: feature activations; the last dimension indexes dictionary features.
        dict_size: number of dictionary features, used to normalize ``frac_alive``.

    Returns:
        dict with l2_loss, l1_loss, mse_loss, l0, frac_alive,
        frac_variance_explained, cossim and l2_ratio.
    """
    residual = x - x_hat
    l2_loss = t.linalg.norm(residual, dim=-1).mean()
    mse_loss = residual.pow(2).sum(dim=-1).mean()
    l1_loss = f.norm(p=1, dim=-1).mean()
    l0 = (f != 0).float().sum(dim=-1).mean()
    # A feature is alive if it fires on at least one input in the batch. All
    # leading dims are flattened, so (batch, dict) and (batch, seq, dict) both
    # work. flatten(0, 1) on a 2-D tensor would collapse it to 1-D and count
    # "any feature fired anywhere" as a single live feature.
    frac_alive = (f.reshape(-1, f.shape[-1]) != 0).any(dim=0).sum() / dict_size

    # cosine similarity between x and x_hat
    x_normed = x / t.linalg.norm(x, dim=-1, keepdim=True)
    x_hat_normed = x_hat / t.linalg.norm(x_hat, dim=-1, keepdim=True)
    cossim = (x_normed * x_hat_normed).sum(dim=-1).mean()

    # l2 ratio
    l2_ratio = (t.linalg.norm(x_hat, dim=-1) / t.linalg.norm(x, dim=-1)).mean()

    # variance explained, summed over activation dimensions
    total_variance = t.var(x, dim=0).sum()
    residual_variance = t.var(residual, dim=0).sum()
    frac_variance_explained = 1 - residual_variance / total_variance

    return {
        "l2_loss": l2_loss.item(),
        "l1_loss": l1_loss.item(),
        "mse_loss": mse_loss.item(),
        "l0": l0.item(),
        "frac_alive": frac_alive.item(),
        "frac_variance_explained": frac_variance_explained.item(),
        "cossim": cossim.item(),
        "l2_ratio": l2_ratio.item(),
    }


def evaluate(
    dictionary, # a dictionary
    activations, # an ActivationBuffer-like object: iterated for activations, and its text_batch/model/submodule are used for loss recovered
    max_len=128, # max context length for loss recovered
    batch_size=8, # batch size for loss recovered
    num_batches=100, # number of batches to evaluate (must be positive)
    io="out", # can be 'in', 'out', or 'in_and_out'
    normalize_batch=False, # normalize batch before passing through dictionary
    tracer_args=None, # kwargs for the model traces; None -> {'use_cache': False, 'output_attentions': False}
    device="cpu",
):
    """Evaluate ``dictionary`` over ``num_batches`` batches from ``activations``.

    Returns:
        dict mapping each metric name (l2_loss, l1_loss, mse_loss, l0,
        frac_alive, frac_variance_explained, cossim, l2_ratio, frac_recovered)
        to its mean over the evaluated batches.

    Raises:
        StopIteration: if ``activations`` runs out before ``num_batches`` batches.
    """
    if tracer_args is None:
        tracer_args = {'use_cache': False, 'output_attentions': False}

    with t.no_grad():
        metrics_lists = {
            "l2_loss": [], "l1_loss": [], "mse_loss": [], "l0": [],
            "frac_alive": [], "frac_variance_explained": [], "cossim": [],
            "l2_ratio": [], "frac_recovered": []
        }

        for _ in tqdm(range(num_batches)):
            try:
                x = next(activations).to(device)
            except StopIteration:
                raise StopIteration(
                    "Not enough activations in buffer. Pass a buffer with a smaller batch size or more data."
                ) from None
            if normalize_batch:
                x = x / x.norm(dim=-1).mean() * (dictionary.activation_dim ** 0.5)

            x_hat, f = dictionary(x, output_features=True)
            batch_metrics = _reconstruction_metrics(x, x_hat, f, dictionary.dict_size)

            # compute loss recovered on a fresh text batch from the buffer
            loss_original, loss_reconstructed, loss_zero = loss_recovered(
                activations.text_batch(batch_size=batch_size),
                activations.model,
                activations.submodule,
                dictionary,
                max_len=max_len,
                normalize_batch=normalize_batch,
                io=io,
                tracer_args=tracer_args
            )
            frac_recovered = (loss_reconstructed - loss_zero) / (loss_original - loss_zero)
            batch_metrics["frac_recovered"] = frac_recovered.item()

            for name, value in batch_metrics.items():
                metrics_lists[name].append(value)

        out = {k: sum(v) / len(v) for k, v in metrics_lists.items()}

    return out
--------------------------------------------------------------------------------
/table/run_parallel.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # GPT-2 training commands
4 | python3 table/train_switch_table.py --device cuda:0 --layer 2 --lm openai-community/gpt2 --ks 64 --activation_dim 768 --dict_ratio 32 --num_experts 8 --type resid --sae_type switch --steps 20000 &
5 | python3 table/train_switch_table.py --device cuda:1 --layer 2 --lm openai-community/gpt2 --ks 64 --activation_dim 768 --dict_ratio 32 --num_experts 8 --type attn --sae_type switch --steps 20000 &
6 | python3 table/train_switch_table.py --device cuda:2 --layer 2 --lm openai-community/gpt2 --ks 64 --activation_dim 768 --dict_ratio 32 --num_experts 8 --type mlp --sae_type switch --steps 20000 &
7 | python3 table/train_switch_table.py --device cuda:3 --layer 2 --lm openai-community/gpt2 --ks 64 --activation_dim 768 --dict_ratio 32 --num_experts 8 --type resid --sae_type topk --steps 20000 &
8 | python3 table/train_switch_table.py --device cuda:4 --layer 2 --lm openai-community/gpt2 --ks 64 --activation_dim 768 --dict_ratio 32 --num_experts 8 --type attn --sae_type topk --steps 20000 &
9 | python3 table/train_switch_table.py --device cuda:5 --layer 2 --lm openai-community/gpt2 --ks 64 --activation_dim 768 --dict_ratio 32 --num_experts 8 --type mlp --sae_type topk --steps 20000 &
10 | python3 table/train_switch_table.py --device cuda:6 --layer 4 --lm openai-community/gpt2 --ks 64 --activation_dim 768 --dict_ratio 32 --num_experts 8 --type resid --sae_type switch --steps 20000 &
11 | python3 table/train_switch_table.py --device cuda:7 --layer 4 --lm openai-community/gpt2 --ks 64 --activation_dim 768 --dict_ratio 32 --num_experts 8 --type attn --sae_type switch --steps 20000 &
12 |
13 | wait
14 |
15 | python3 table/train_switch_table.py --device cuda:0 --layer 4 --lm openai-community/gpt2 --ks 64 --activation_dim 768 --dict_ratio 32 --num_experts 8 --type mlp --sae_type switch --steps 20000 &
16 | python3 table/train_switch_table.py --device cuda:1 --layer 4 --lm openai-community/gpt2 --ks 64 --activation_dim 768 --dict_ratio 32 --num_experts 8 --type resid --sae_type topk --steps 20000 &
17 | python3 table/train_switch_table.py --device cuda:2 --layer 4 --lm openai-community/gpt2 --ks 64 --activation_dim 768 --dict_ratio 32 --num_experts 8 --type attn --sae_type topk --steps 20000 &
18 | python3 table/train_switch_table.py --device cuda:3 --layer 4 --lm openai-community/gpt2 --ks 64 --activation_dim 768 --dict_ratio 32 --num_experts 8 --type mlp --sae_type topk --steps 20000 &
19 | python3 table/train_switch_table.py --device cuda:4 --layer 8 --lm openai-community/gpt2 --ks 64 --activation_dim 768 --dict_ratio 32 --num_experts 8 --type resid --sae_type switch --steps 20000 &
20 | python3 table/train_switch_table.py --device cuda:5 --layer 8 --lm openai-community/gpt2 --ks 64 --activation_dim 768 --dict_ratio 32 --num_experts 8 --type attn --sae_type switch --steps 20000 &
21 | python3 table/train_switch_table.py --device cuda:6 --layer 8 --lm openai-community/gpt2 --ks 64 --activation_dim 768 --dict_ratio 32 --num_experts 8 --type mlp --sae_type switch --steps 20000 &
22 | python3 table/train_switch_table.py --device cuda:7 --layer 8 --lm openai-community/gpt2 --ks 64 --activation_dim 768 --dict_ratio 32 --num_experts 8 --type resid --sae_type topk --steps 20000 &
23 |
24 | wait
25 |
26 | python3 table/train_switch_table.py --device cuda:0 --layer 8 --lm openai-community/gpt2 --ks 64 --activation_dim 768 --dict_ratio 32 --num_experts 8 --type attn --sae_type topk --steps 20000 &
27 | python3 table/train_switch_table.py --device cuda:1 --layer 8 --lm openai-community/gpt2 --ks 64 --activation_dim 768 --dict_ratio 32 --num_experts 8 --type mlp --sae_type topk --steps 20000 &
28 | python3 table/train_switch_table.py --device cuda:2 --layer 10 --lm openai-community/gpt2 --ks 64 --activation_dim 768 --dict_ratio 32 --num_experts 8 --type resid --sae_type switch --steps 20000 &
29 | python3 table/train_switch_table.py --device cuda:3 --layer 10 --lm openai-community/gpt2 --ks 64 --activation_dim 768 --dict_ratio 32 --num_experts 8 --type attn --sae_type switch --steps 20000 &
30 | python3 table/train_switch_table.py --device cuda:4 --layer 10 --lm openai-community/gpt2 --ks 64 --activation_dim 768 --dict_ratio 32 --num_experts 8 --type mlp --sae_type switch --steps 20000 &
31 | python3 table/train_switch_table.py --device cuda:5 --layer 10 --lm openai-community/gpt2 --ks 64 --activation_dim 768 --dict_ratio 32 --num_experts 8 --type resid --sae_type topk --steps 20000 &
32 | python3 table/train_switch_table.py --device cuda:6 --layer 10 --lm openai-community/gpt2 --ks 64 --activation_dim 768 --dict_ratio 32 --num_experts 8 --type attn --sae_type topk --steps 20000 &
33 | python3 table/train_switch_table.py --device cuda:7 --layer 10 --lm openai-community/gpt2 --ks 64 --activation_dim 768 --dict_ratio 32 --num_experts 8 --type mlp --sae_type topk --steps 20000 &
34 |
35 | wait
36 |
37 | python3 table/train_switch_table.py --device cuda:0 --layer 12 --lm openai-community/gpt2 --ks 64 --activation_dim 768 --dict_ratio 32 --num_experts 8 --type resid --sae_type switch --steps 20000 &
38 | python3 table/train_switch_table.py --device cuda:1 --layer 12 --lm openai-community/gpt2 --ks 64 --activation_dim 768 --dict_ratio 32 --num_experts 8 --type attn --sae_type switch --steps 20000 &
39 | python3 table/train_switch_table.py --device cuda:2 --layer 12 --lm openai-community/gpt2 --ks 64 --activation_dim 768 --dict_ratio 32 --num_experts 8 --type mlp --sae_type switch --steps 20000 &
40 | python3 table/train_switch_table.py --device cuda:3 --layer 12 --lm openai-community/gpt2 --ks 64 --activation_dim 768 --dict_ratio 32 --num_experts 8 --type resid --sae_type topk --steps 20000 &
41 | python3 table/train_switch_table.py --device cuda:4 --layer 12 --lm openai-community/gpt2 --ks 64 --activation_dim 768 --dict_ratio 32 --num_experts 8 --type attn --sae_type topk --steps 20000 &
42 | python3 table/train_switch_table.py --device cuda:5 --layer 12 --lm openai-community/gpt2 --ks 64 --activation_dim 768 --dict_ratio 32 --num_experts 8 --type mlp --sae_type topk --steps 20000 &
43 |
44 | wait
45 |
46 | # Gemma-2B training commands
47 | python3 table/train_switch_table.py --device cuda:0 --layer 12 --lm google/gemma-2b --ks 64 --activation_dim 2048 --dict_ratio 32 --num_experts 8 --type resid --sae_type switch --steps 2000000 &
48 | python3 table/train_switch_table.py --device cuda:1 --layer 12 --lm google/gemma-2b --ks 64 --activation_dim 2048 --dict_ratio 32 --num_experts 8 --type resid --sae_type topk --steps 2000000 &
49 |
50 | wait
51 |
--------------------------------------------------------------------------------
/table/train_switch_table.py:
--------------------------------------------------------------------------------
1 | # %%
2 | import os
3 | import sys
4 |
5 | # Add the parent directory to the sys path
6 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
7 |
8 |
9 | from dictionary_learning.trainers.top_k import AutoEncoderTopK, TrainerTopK
10 | from nnsight import LanguageModel
11 | import torch as t
12 | from dictionary_learning import ActivationBuffer
13 | from dictionary_learning.training import trainSAE
14 | from dictionary_learning.utils import hf_dataset_to_generator, cfg_filename, str2bool
15 | from dictionary_learning.trainers.switch import SwitchAutoEncoder, SwitchTrainer
16 | from dictionary_learning.evaluation import evaluate
17 | import wandb
18 | import argparse
19 | import itertools
20 | from config import hf
21 |
22 |
23 | # %%
# Command-line configuration for training one SAE family on one submodule.
parser = argparse.ArgumentParser()
parser.add_argument("--device", type=str, default="cuda:1", help="Device to run on")
# --lm has no sensible default: it selects both the model to load and the
# submodule layout below, and LanguageModel(None, ...) would fail with an opaque
# error far from the cause. Fail fast at argument parsing instead.
parser.add_argument("--lm", type=str, required=True, help="Language model to use")
parser.add_argument(
    "--layer", type=int, default=12, help="Layer to extract activations from"
)
parser.add_argument(
    "--type",
    type=str,
    default="resid",
    choices=["resid", "mlp", "attn"],
    help="Submodule to read: the whole block (resid), its MLP, or its attention",
)
parser.add_argument(
    "--sae_type",
    type=str,
    default="switch",
    choices=["switch", "topk"],
    help="SAE architecture to train",
)
parser.add_argument("--ks", nargs="+", type=int, default=[64], help="List of k values")
parser.add_argument(
    "--activation_dim", type=int, default=2048, help="Dimension of activations"
)
parser.add_argument("--dict_ratio", type=int, default=32, help="Dictionary size ratio")
parser.add_argument(
    "--num_experts", nargs="+", type=int, default=[8], help="List of number of experts"
)
parser.add_argument(
    "--steps", type=int, default=10000, help="Number of steps to train for"
)

args = parser.parse_args()

device = args.device
layer = args.layer
lm = args.lm
print(lm)
ks = args.ks
activation_dim = args.activation_dim
dict_ratio = args.dict_ratio  # dict_size = dict_ratio * activation_dim
num_experts = args.num_experts  # only used by the switch sweep
# Fixed (non-CLI) sweep values, only used when --sae_type is "switch".
lb_alphas = [3]
heavisides = [False]
n_ctxs = 3e4  # passed to ActivationBuffer as n_ctxs
batch_size = 8192  # passed to ActivationBuffer as out_batch_size
steps = args.steps
63 |
64 | # %%
65 |
model = LanguageModel(lm, dispatch=True, device_map=device)
is_gpt2 = lm == "openai-community/gpt2"

# %%
# Choose the module whose output feeds the SAE. GPT-2 keeps its blocks under
# `transformer.h` with an `attn` child; every other model is assumed to use the
# `model.layers[i]` / `self_attn` layout (e.g. google/gemma-2b).
block = model.transformer.h[layer] if is_gpt2 else model.model.layers[layer]
if args.type == "resid":
    submodule = block
elif args.type == "mlp":
    submodule = block.mlp
elif args.type == "attn":
    submodule = block.attn if is_gpt2 else block.self_attn

# %%


# Stream activations from the HF dataset. The buffer device is "cpu", and
# contexts are refreshed 512 at a time for GPT-2 and 64 otherwise.
data = hf_dataset_to_generator(hf)
buffer = ActivationBuffer(
    data,
    model,
    submodule,
    d_submodule=activation_dim,
    n_ctxs=n_ctxs,
    device="cpu",
    out_batch_size=batch_size,
    refresh_batch_size=512 if is_gpt2 else 64,
)
102 |
# Keys shared by both SAE families. They are spliced between the trainer/class
# keys and "wandb_name" so each final dict keeps the original key order.
shared_config = {
    "activation_dim": activation_dim,
    "dict_size": dict_ratio * activation_dim,
    "auxk_alpha": 1 / 32,
    "decay_start": int(steps * 0.8),  # set at 80% of the run
    "steps": steps,
    "seed": 0,
    "device": device,
    "layer": layer,
    "lm_name": lm,
}

if args.sae_type == "switch":
    base_trainer_config = (
        {"trainer": SwitchTrainer, "dict_class": SwitchAutoEncoder}
        | shared_config
        | {"wandb_name": "SwitchAutoEncoder"}
    )
    # One trainer per point of the (k, experts, heaviside, lb_alpha) grid.
    trainer_configs = [
        base_trainer_config
        | {"k": k_val, "experts": n_exp, "heaviside": hv, "lb_alpha": alpha}
        for k_val, n_exp, hv, alpha in itertools.product(
            ks, num_experts, heavisides, lb_alphas
        )
    ]
else:
    base_trainer_config = (
        {"trainer": TrainerTopK, "dict_class": AutoEncoderTopK}
        | shared_config
        | {"wandb_name": "AutoEncoderTopK"}
    )
    # One TopK trainer per requested k.
    trainer_configs = [base_trainer_config | {"k": k_val} for k_val in ks]
149 |
150 |
# A single wandb run records every trainer config, keyed "<wandb_name>-<index>".
wandb.init(
    entity="josh_engels",
    project="Switch",
    config={
        f'{cfg["wandb_name"]}-{idx}': cfg for idx, cfg in enumerate(trainer_configs)
    },
)

trainSAE(
    buffer,
    trainer_configs=trainer_configs,
    save_dir="dictionaries",
    log_steps=1,
    steps=steps,
    save_steps=1000,
)

print("Training finished. Evaluating SAE...", flush=True)
# NOTE(review): checkpoints are found only through cfg_filename(config), and the
# config does not record args.type. Concurrent launches that differ only in
# --type may therefore share a directory unless cfg_filename separates them
# (e.g. via "device"). Verify.
for idx, cfg in enumerate(trainer_configs):
    ckpt_path = f"dictionaries/{cfg_filename(cfg)}/ae.pt"
    if args.sae_type == "switch":
        ae = SwitchAutoEncoder.from_pretrained(
            ckpt_path,
            k=cfg["k"],
            experts=cfg["experts"],
            heaviside=cfg["heaviside"],
            device=device,
        )
    else:
        ae = AutoEncoderTopK.from_pretrained(ckpt_path, k=cfg["k"], device=device)
    metrics = evaluate(ae, buffer, device=device)
    # Final metrics are logged at step steps + 1, after the training steps.
    prefix = f'{cfg["wandb_name"]}-{idx}'
    wandb.log(
        {f"{prefix}/{name}": value for name, value in metrics.items()},
        step=steps + 1,
    )
wandb.finish()
198 |
--------------------------------------------------------------------------------
/train-gated.py:
--------------------------------------------------------------------------------
1 | from nnsight import LanguageModel
2 | import torch as t
3 | from dictionary_learning import ActivationBuffer
4 | from dictionary_learning.training import trainSAE
5 | from dictionary_learning.utils import hf_dataset_to_generator, cfg_filename
6 | from dictionary_learning.dictionary import GatedAutoEncoder
7 | from dictionary_learning.trainers.gdm import GatedSAETrainer
8 | from dictionary_learning.evaluation import evaluate
9 | import wandb
10 | import argparse
11 | from config import lm, activation_dim, layer, hf, steps, n_ctxs
12 |
parser = argparse.ArgumentParser()
parser.add_argument("--gpu", required=True)
parser.add_argument("--lr", type=float, default=1e-3)  # other values used: 3e-4, 5e-5
parser.add_argument("--dict_ratio", type=int, default=32)
parser.add_argument("--l1_penalties", nargs="+", type=float, required=True)
args = parser.parse_args()

# The GPU index becomes a CUDA device string used by both the LM and the buffer.
device = f"cuda:{args.gpu}"
model = LanguageModel(lm, dispatch=True, device_map=device)
submodule = model.transformer.h[layer]  # transformer block `layer` (GPT-2 naming)
data = hf_dataset_to_generator(hf)
buffer = ActivationBuffer(
    data,
    model,
    submodule,
    d_submodule=activation_dim,
    n_ctxs=n_ctxs,
    device=device,
)

# Settings shared by every gated trainer; only the L1 penalty varies.
base_trainer_config = dict(
    trainer=GatedSAETrainer,
    dict_class=GatedAutoEncoder,
    activation_dim=activation_dim,
    dict_size=args.dict_ratio * activation_dim,
    lr=args.lr,
    warmup_steps=1000,
    resample_steps=None,
    seed=0,
    device=device,
    layer=layer,
    lm_name=lm,
    wandb_name="GatedSAETrainer",
)

trainer_configs = [
    {**base_trainer_config, "l1_penalty": penalty} for penalty in args.l1_penalties
]
42 |
# One wandb run holds every trainer config, keyed "<wandb_name>-<index>".
run_config = {}
for idx, cfg in enumerate(trainer_configs):
    run_config[f'{cfg["wandb_name"]}-{idx}'] = cfg
wandb.init(entity="amudide", project="Gated-BigLR", config=run_config)

trainSAE(
    buffer,
    trainer_configs=trainer_configs,
    save_dir="dictionaries",
    log_steps=1,
    steps=steps,
)

print("Training finished. Evaluating SAE...", flush=True)
# Load each dictionary from dictionaries/<cfg_filename>/ae.pt and log its metrics.
for idx, cfg in enumerate(trainer_configs):
    ckpt_path = f"dictionaries/{cfg_filename(cfg)}/ae.pt"
    ae = GatedAutoEncoder.from_pretrained(ckpt_path, device=device)
    metrics = evaluate(ae, buffer, device=device)
    prefix = f'{cfg["wandb_name"]}-{idx}'
    wandb.log({f"{prefix}/{name}": value for name, value in metrics.items()}, step=steps + 1)
wandb.finish()
--------------------------------------------------------------------------------
/train-jump.py:
--------------------------------------------------------------------------------
1 | from nnsight import LanguageModel
2 | import torch as t
3 | from dictionary_learning import ActivationBuffer
4 | from dictionary_learning.training import trainSAE
5 | from dictionary_learning.utils import hf_dataset_to_generator, cfg_filename
6 | from dictionary_learning.dictionary import JumpReluAutoEncoder
7 | from dictionary_learning.trainers.jump_relu import JumpReluTrainer
8 | from dictionary_learning.evaluation import evaluate
9 | import wandb
10 | import argparse
11 | from config import lm, activation_dim, layer, hf, steps, n_ctxs
12 |
parser = argparse.ArgumentParser()
parser.add_argument("--gpu", required=True)
parser.add_argument("--lr", type=float, default=7e-5)
parser.add_argument("--dict_ratio", type=int, default=32)
parser.add_argument("--l0_penalties", nargs="+", type=float, required=True)
args = parser.parse_args()

# The GPU index becomes a CUDA device string used by both the LM and the buffer.
device = f"cuda:{args.gpu}"
model = LanguageModel(lm, dispatch=True, device_map=device)
submodule = model.transformer.h[layer]  # transformer block `layer` (GPT-2 naming)
data = hf_dataset_to_generator(hf)
buffer = ActivationBuffer(
    data,
    model,
    submodule,
    d_submodule=activation_dim,
    n_ctxs=n_ctxs,
    device=device,
)

# Settings shared by every JumpReLU trainer; only the L0 penalty varies.
base_trainer_config = dict(
    trainer=JumpReluTrainer,
    dict_class=JumpReluAutoEncoder,
    activation_dim=activation_dim,
    dict_size=args.dict_ratio * activation_dim,
    lr=args.lr,
    warmup_steps=1000,
    seed=0,
    device=device,
    layer=layer,
    lm_name=lm,
    wandb_name="JumpReluTrainer",
)

trainer_configs = [
    {**base_trainer_config, "l0_penalty": penalty} for penalty in args.l0_penalties
]
41 |
# One wandb run holds every trainer config, keyed "<wandb_name>-<index>".
run_config = {}
for idx, cfg in enumerate(trainer_configs):
    run_config[f'{cfg["wandb_name"]}-{idx}'] = cfg
wandb.init(entity="amudide", project="Jump", config=run_config)

trainSAE(
    buffer,
    trainer_configs=trainer_configs,
    save_dir="dictionaries",
    log_steps=1,
    steps=steps,
)

print("Training finished. Evaluating SAE...", flush=True)
# Load each dictionary from dictionaries/<cfg_filename>/ae.pt and log its metrics.
for idx, cfg in enumerate(trainer_configs):
    ckpt_path = f"dictionaries/{cfg_filename(cfg)}/ae.pt"
    ae = JumpReluAutoEncoder.from_pretrained(ckpt_path, device=device)
    metrics = evaluate(ae, buffer, device=device)
    prefix = f'{cfg["wandb_name"]}-{idx}'
    wandb.log({f"{prefix}/{name}": value for name, value in metrics.items()}, step=steps + 1)
wandb.finish()
--------------------------------------------------------------------------------
/train-moe.py:
--------------------------------------------------------------------------------
1 | from nnsight import LanguageModel
2 | import torch as t
3 | from dictionary_learning import ActivationBuffer
4 | from dictionary_learning.training import trainSAE
5 | from dictionary_learning.utils import hf_dataset_to_generator, cfg_filename, str2bool
6 | from dictionary_learning.trainers.moe import MoEAutoEncoder, MoETrainer
7 | from dictionary_learning.evaluation import evaluate
8 | import wandb
9 | import argparse
10 | import itertools
11 | from config import lm, activation_dim, layer, hf, steps, n_ctxs
12 |
parser = argparse.ArgumentParser()
parser.add_argument("--gpu", required=True)
parser.add_argument("--dict_ratio", type=int, default=32)
parser.add_argument("--ks", nargs="+", type=int, required=True)
parser.add_argument("--num_experts", nargs="+", type=int, required=True)
parser.add_argument("--es", nargs="+", type=int, required=True)
parser.add_argument("--heavisides", nargs="+", type=str2bool, required=True)
args = parser.parse_args()

# The GPU index becomes a CUDA device string used by both the LM and the buffer.
device = f"cuda:{args.gpu}"
model = LanguageModel(lm, dispatch=True, device_map=device)
submodule = model.transformer.h[layer]  # transformer block `layer` (GPT-2 naming)
data = hf_dataset_to_generator(hf)
buffer = ActivationBuffer(
    data,
    model,
    submodule,
    d_submodule=activation_dim,
    n_ctxs=n_ctxs,
    device=device,
)

# Settings shared by every MoE trainer in the sweep.
base_trainer_config = dict(
    trainer=MoETrainer,
    dict_class=MoEAutoEncoder,
    activation_dim=activation_dim,
    dict_size=args.dict_ratio * activation_dim,
    auxk_alpha=1 / 32,
    decay_start=int(steps * 0.8),
    steps=steps,
    seed=0,
    device=device,
    layer=layer,
    lm_name=lm,
    wandb_name="MoEAutoEncoder",
)

# One trainer per point of the (k, experts, e, heaviside) grid.
trainer_configs = [
    {**base_trainer_config, "k": k, "experts": n_experts, "e": e, "heaviside": heaviside}
    for k, n_experts, e, heaviside in itertools.product(
        args.ks, args.num_experts, args.es, args.heavisides
    )
]
44 |
# One wandb run holds every trainer config, keyed "<wandb_name>-<index>".
run_config = {}
for idx, cfg in enumerate(trainer_configs):
    run_config[f'{cfg["wandb_name"]}-{idx}'] = cfg
wandb.init(entity="amudide", project="MoE", config=run_config)

trainSAE(
    buffer,
    trainer_configs=trainer_configs,
    save_dir="dictionaries",
    log_steps=1,
    steps=steps,
)

print("Training finished. Evaluating SAE...", flush=True)
# Rebuild each MoE dictionary with its sweep hyperparameters, then log metrics.
for idx, cfg in enumerate(trainer_configs):
    ckpt_path = f"dictionaries/{cfg_filename(cfg)}/ae.pt"
    ae = MoEAutoEncoder.from_pretrained(
        ckpt_path,
        k=cfg["k"],
        experts=cfg["experts"],
        e=cfg["e"],
        heaviside=cfg["heaviside"],
        device=device,
    )
    metrics = evaluate(ae, buffer, device=device)
    prefix = f'{cfg["wandb_name"]}-{idx}'
    wandb.log({f"{prefix}/{name}": value for name, value in metrics.items()}, step=steps + 1)
wandb.finish()
--------------------------------------------------------------------------------
/train-relu.py:
--------------------------------------------------------------------------------
1 | from nnsight import LanguageModel
2 | import torch as t
3 | from dictionary_learning import ActivationBuffer
4 | from dictionary_learning.training import trainSAE
5 | from dictionary_learning.utils import hf_dataset_to_generator, cfg_filename
6 | from dictionary_learning.dictionary import AutoEncoderNew
7 | from dictionary_learning.trainers.standard_new import StandardTrainerNew
8 | from dictionary_learning.evaluation import evaluate
9 | import wandb
10 | import argparse
11 | from config import lm, activation_dim, layer, hf, steps, n_ctxs
12 |
parser = argparse.ArgumentParser()
parser.add_argument("--gpu", required=True)
parser.add_argument("--lr", type=float, default=5e-5)  # alternative used: 3e-4
parser.add_argument("--dict_ratio", type=int, default=32)
parser.add_argument("--l1_penalties", nargs="+", type=float, required=True)
args = parser.parse_args()

# The GPU index becomes a CUDA device string used by both the LM and the buffer.
device = f"cuda:{args.gpu}"
model = LanguageModel(lm, dispatch=True, device_map=device)
submodule = model.transformer.h[layer]  # transformer block `layer` (GPT-2 naming)
data = hf_dataset_to_generator(hf)
buffer = ActivationBuffer(
    data,
    model,
    submodule,
    d_submodule=activation_dim,
    n_ctxs=n_ctxs,
    device=device,
)

# Settings shared by every ReLU trainer; only the L1 penalty varies.
# NOTE(review): unlike the sibling scripts, no 'layer'/'lm_name' keys are set
# here. Confirm this is intentional, since it may affect cfg_filename and wandb.
base_trainer_config = dict(
    trainer=StandardTrainerNew,
    dict_class=AutoEncoderNew,
    activation_dim=activation_dim,
    dict_size=args.dict_ratio * activation_dim,
    lr=args.lr,
    lambda_warm_steps=int(steps * 0.05),  # first 5% of the run
    decay_start=int(steps * 0.8),  # set at 80% of the run
    steps=steps,
    seed=0,
    device=device,
    wandb_name="StandardTrainerNew_Anthropic",
)

trainer_configs = [
    {**base_trainer_config, "l1_penalty": penalty} for penalty in args.l1_penalties
]
41 |
# One wandb run holds every trainer config, keyed "<wandb_name>-<index>".
run_config = {}
for idx, cfg in enumerate(trainer_configs):
    run_config[f'{cfg["wandb_name"]}-{idx}'] = cfg
wandb.init(entity="amudide", project="ReLU", config=run_config)

trainSAE(
    buffer,
    trainer_configs=trainer_configs,
    save_dir="dictionaries",
    log_steps=1,
    steps=steps,
)

print("Training finished. Evaluating SAE...", flush=True)
# Load each dictionary from dictionaries/<cfg_filename>/ae.pt and log its metrics.
for idx, cfg in enumerate(trainer_configs):
    ckpt_path = f"dictionaries/{cfg_filename(cfg)}/ae.pt"
    ae = AutoEncoderNew.from_pretrained(ckpt_path, device=device)
    metrics = evaluate(ae, buffer, device=device)
    prefix = f'{cfg["wandb_name"]}-{idx}'
    wandb.log({f"{prefix}/{name}": value for name, value in metrics.items()}, step=steps + 1)
wandb.finish()
--------------------------------------------------------------------------------
/train-switch-1on.py:
--------------------------------------------------------------------------------
1 | from nnsight import LanguageModel
2 | import torch as t
3 | from dictionary_learning import ActivationBuffer
4 | from dictionary_learning.training import trainSAE
5 | from dictionary_learning.utils import hf_dataset_to_generator, cfg_filename, str2bool
6 | from dictionary_learning.trainers.switch1on import SwitchAutoEncoder, SwitchTrainer
7 | from dictionary_learning.evaluation import evaluate
8 | import wandb
9 | import argparse
10 | import itertools
11 | from config import lm, activation_dim, layer, hf, steps, n_ctxs
12 |
parser = argparse.ArgumentParser()
parser.add_argument("--gpu", required=True)
parser.add_argument("--dict_ratio", type=int, default=32)
parser.add_argument("--ks", nargs="+", type=int, required=True)
parser.add_argument("--num_experts", nargs="+", type=int, required=True)
parser.add_argument("--lb_alphas", nargs="+", type=float, default=[3.0])
parser.add_argument("--heavisides", nargs="+", type=str2bool, default=[False])
args = parser.parse_args()

# The GPU index becomes a CUDA device string used by both the LM and the buffer.
device = f"cuda:{args.gpu}"
model = LanguageModel(lm, dispatch=True, device_map=device)
submodule = model.transformer.h[layer]  # transformer block `layer` (GPT-2 naming)
data = hf_dataset_to_generator(hf)
buffer = ActivationBuffer(
    data,
    model,
    submodule,
    d_submodule=activation_dim,
    n_ctxs=n_ctxs,
    device=device,
)

# Settings shared by every trainer in the switch (1-always-on) sweep.
base_trainer_config = dict(
    trainer=SwitchTrainer,
    dict_class=SwitchAutoEncoder,
    activation_dim=activation_dim,
    dict_size=args.dict_ratio * activation_dim,
    auxk_alpha=1 / 32,
    decay_start=int(steps * 0.8),
    steps=steps,
    seed=0,
    device=device,
    layer=layer,
    lm_name=lm,
    wandb_name="SwitchAutoEncoder",
)

# One trainer per point of the (k, experts, heaviside, lb_alpha) grid.
trainer_configs = [
    {
        **base_trainer_config,
        "k": k,
        "experts": n_experts,
        "heaviside": heaviside,
        "lb_alpha": lb_alpha,
    }
    for k, n_experts, heaviside, lb_alpha in itertools.product(
        args.ks, args.num_experts, args.heavisides, args.lb_alphas
    )
]
44 |
# One wandb run holds every trainer config, keyed "<wandb_name>-<index>".
run_config = {}
for idx, cfg in enumerate(trainer_configs):
    run_config[f'{cfg["wandb_name"]}-{idx}'] = cfg
wandb.init(entity="amudide", project="Switch (1 Always On)", config=run_config)

trainSAE(
    buffer,
    trainer_configs=trainer_configs,
    save_dir="dictionaries",
    log_steps=1,
    steps=steps,
)

print("Training finished. Evaluating SAE...", flush=True)
# Rebuild each switch dictionary with its sweep hyperparameters, then log metrics.
for idx, cfg in enumerate(trainer_configs):
    ckpt_path = f"dictionaries/{cfg_filename(cfg)}/ae.pt"
    ae = SwitchAutoEncoder.from_pretrained(
        ckpt_path,
        k=cfg["k"],
        experts=cfg["experts"],
        heaviside=cfg["heaviside"],
        device=device,
    )
    metrics = evaluate(ae, buffer, device=device)
    prefix = f'{cfg["wandb_name"]}-{idx}'
    wandb.log({f"{prefix}/{name}": value for name, value in metrics.items()}, step=steps + 1)
wandb.finish()
--------------------------------------------------------------------------------
/train-switch-flop.py:
--------------------------------------------------------------------------------
1 | from nnsight import LanguageModel
2 | import torch as t
3 | from dictionary_learning import ActivationBuffer
4 | from dictionary_learning.training import trainSAE
5 | from dictionary_learning.utils import hf_dataset_to_generator, cfg_filename, str2bool
6 | from dictionary_learning.trainers.switch import SwitchAutoEncoder, SwitchTrainer
7 | from dictionary_learning.evaluation import evaluate
8 | import wandb
9 | import argparse
10 | import itertools
11 | from config import lm, activation_dim, layer, hf, steps, n_ctxs
12 |
parser = argparse.ArgumentParser()
parser.add_argument("--gpu", required=True)
parser.add_argument("--dict_ratio", type=int, default=32)
parser.add_argument("--ks", nargs="+", type=int, required=True)
parser.add_argument("--num_experts", nargs="+", type=int, required=True)
parser.add_argument("--lb_alphas", nargs="+", type=float, default=[3.0])
parser.add_argument("--heavisides", nargs="+", type=str2bool, default=[False])
args = parser.parse_args()

# The GPU index becomes a CUDA device string used by both the LM and the buffer.
device = f"cuda:{args.gpu}"
model = LanguageModel(lm, dispatch=True, device_map=device)
submodule = model.transformer.h[layer]  # transformer block `layer` (GPT-2 naming)
data = hf_dataset_to_generator(hf)
## 1/2 context, 1/4 batch size to save memory
buffer = ActivationBuffer(
    data,
    model,
    submodule,
    d_submodule=activation_dim,
    out_batch_size=2048,
    n_ctxs=n_ctxs // 2,
    device=device,
)

# Settings shared across the FLOP-matched sweep. dict_size is deliberately
# absent here because it is set per trainer below.
base_trainer_config = dict(
    trainer=SwitchTrainer,
    dict_class=SwitchAutoEncoder,
    activation_dim=activation_dim,
    auxk_alpha=1 / 32,
    decay_start=int(steps * 0.8),
    steps=steps,
    seed=0,
    device=device,
    layer=layer,
    lm_name=lm,
    wandb_name="SwitchAutoEncoder",
)

# FLOP matching: the dictionary grows linearly with the number of experts.
trainer_configs = [
    {
        **base_trainer_config,
        "k": k,
        "experts": n_experts,
        "dict_size": args.dict_ratio * activation_dim * n_experts,
        "heaviside": heaviside,
        "lb_alpha": lb_alpha,
    }
    for k, n_experts, heaviside, lb_alpha in itertools.product(
        args.ks, args.num_experts, args.heavisides, args.lb_alphas
    )
]
43 |
# One wandb run holds every trainer config, keyed "<wandb_name>-<index>".
run_config = {}
for idx, cfg in enumerate(trainer_configs):
    run_config[f'{cfg["wandb_name"]}-{idx}'] = cfg
wandb.init(entity="amudide", project="Switch (FLOP Matched, LB)", config=run_config)

trainSAE(
    buffer,
    trainer_configs=trainer_configs,
    save_dir="dictionaries",
    log_steps=1,
    steps=steps,
)

print("Training finished. Evaluating SAE...", flush=True)
# Rebuild each switch dictionary with its sweep hyperparameters, then log metrics.
for idx, cfg in enumerate(trainer_configs):
    ckpt_path = f"dictionaries/{cfg_filename(cfg)}/ae.pt"
    ae = SwitchAutoEncoder.from_pretrained(
        ckpt_path,
        k=cfg["k"],
        experts=cfg["experts"],
        heaviside=cfg["heaviside"],
        device=device,
    )
    metrics = evaluate(ae, buffer, device=device)
    prefix = f'{cfg["wandb_name"]}-{idx}'
    wandb.log({f"{prefix}/{name}": value for name, value in metrics.items()}, step=steps + 1)
wandb.finish()
--------------------------------------------------------------------------------
/train-switch.py:
--------------------------------------------------------------------------------
1 | from nnsight import LanguageModel
2 | import torch as t
3 | from dictionary_learning import ActivationBuffer
4 | from dictionary_learning.training import trainSAE
5 | from dictionary_learning.utils import hf_dataset_to_generator, cfg_filename, str2bool
6 | from dictionary_learning.trainers.switch import SwitchAutoEncoder, SwitchTrainer
7 | from dictionary_learning.evaluation import evaluate
8 | import wandb
9 | import argparse
10 | import itertools
11 | from config import lm, activation_dim, layer, hf, steps, n_ctxs
12 |
parser = argparse.ArgumentParser()
parser.add_argument("--gpu", required=True)
parser.add_argument("--dict_ratio", type=int, default=32)
parser.add_argument("--ks", nargs="+", type=int, required=True)
parser.add_argument("--num_experts", nargs="+", type=int, required=True)
parser.add_argument("--lb_alphas", nargs="+", type=float, default=[3.0])
parser.add_argument("--heavisides", nargs="+", type=str2bool, default=[False])
args = parser.parse_args()

# The GPU index becomes a CUDA device string used by both the LM and the buffer.
device = f"cuda:{args.gpu}"
model = LanguageModel(lm, dispatch=True, device_map=device)
submodule = model.transformer.h[layer]  # transformer block `layer` (GPT-2 naming)
data = hf_dataset_to_generator(hf)
buffer = ActivationBuffer(
    data,
    model,
    submodule,
    d_submodule=activation_dim,
    n_ctxs=n_ctxs,
    device=device,
)

# Settings shared by every trainer in the switch (load-balanced) sweep.
base_trainer_config = dict(
    trainer=SwitchTrainer,
    dict_class=SwitchAutoEncoder,
    activation_dim=activation_dim,
    dict_size=args.dict_ratio * activation_dim,
    auxk_alpha=1 / 32,
    decay_start=int(steps * 0.8),
    steps=steps,
    seed=0,
    device=device,
    layer=layer,
    lm_name=lm,
    wandb_name="SwitchAutoEncoder",
)

# One trainer per point of the (k, experts, heaviside, lb_alpha) grid.
trainer_configs = [
    {
        **base_trainer_config,
        "k": k,
        "experts": n_experts,
        "heaviside": heaviside,
        "lb_alpha": lb_alpha,
    }
    for k, n_experts, heaviside, lb_alpha in itertools.product(
        args.ks, args.num_experts, args.heavisides, args.lb_alphas
    )
]
44 |
# One wandb run holds every trainer config, keyed "<wandb_name>-<index>".
run_config = {}
for idx, cfg in enumerate(trainer_configs):
    run_config[f'{cfg["wandb_name"]}-{idx}'] = cfg
wandb.init(entity="amudide", project="Switch (LB)", config=run_config)

trainSAE(
    buffer,
    trainer_configs=trainer_configs,
    save_dir="dictionaries",
    log_steps=1,
    steps=steps,
)

print("Training finished. Evaluating SAE...", flush=True)
# Rebuild each switch dictionary with its sweep hyperparameters, then log metrics.
for idx, cfg in enumerate(trainer_configs):
    ckpt_path = f"dictionaries/{cfg_filename(cfg)}/ae.pt"
    ae = SwitchAutoEncoder.from_pretrained(
        ckpt_path,
        k=cfg["k"],
        experts=cfg["experts"],
        heaviside=cfg["heaviside"],
        device=device,
    )
    metrics = evaluate(ae, buffer, device=device)
    prefix = f'{cfg["wandb_name"]}-{idx}'
    wandb.log({f"{prefix}/{name}": value for name, value in metrics.items()}, step=steps + 1)
wandb.finish()
--------------------------------------------------------------------------------
/train-topk.py:
--------------------------------------------------------------------------------
1 | from nnsight import LanguageModel
2 | import torch as t
3 | from dictionary_learning import ActivationBuffer
4 | from dictionary_learning.training import trainSAE
5 | from dictionary_learning.utils import hf_dataset_to_generator, cfg_filename
6 | from dictionary_learning.trainers.top_k import AutoEncoderTopK, TrainerTopK
7 | from dictionary_learning.evaluation import evaluate
8 | import wandb
9 | import argparse
10 | from config import lm, activation_dim, layer, hf, steps, n_ctxs
11 |
parser = argparse.ArgumentParser()
parser.add_argument("--gpu", required=True)
parser.add_argument("--dict_ratio", type=int, default=32)
parser.add_argument("--ks", nargs="+", type=int, required=True)
args = parser.parse_args()

# The GPU index becomes a CUDA device string used by both the LM and the buffer.
device = f"cuda:{args.gpu}"
model = LanguageModel(lm, dispatch=True, device_map=device)
submodule = model.transformer.h[layer]  # transformer block `layer` (GPT-2 naming)
data = hf_dataset_to_generator(hf)
buffer = ActivationBuffer(
    data,
    model,
    submodule,
    d_submodule=activation_dim,
    n_ctxs=n_ctxs,
    device=device,
)

# Settings shared by every TopK trainer; only k varies.
base_trainer_config = dict(
    trainer=TrainerTopK,
    dict_class=AutoEncoderTopK,
    activation_dim=activation_dim,
    dict_size=args.dict_ratio * activation_dim,
    auxk_alpha=1 / 32,
    decay_start=int(steps * 0.8),
    steps=steps,
    seed=0,
    device=device,
    layer=layer,
    lm_name=lm,
    wandb_name="AutoEncoderTopK",
)

trainer_configs = [{**base_trainer_config, "k": k} for k in args.ks]
40 |
# One wandb run holds every trainer config, keyed "<wandb_name>-<index>".
run_config = {}
for idx, cfg in enumerate(trainer_configs):
    run_config[f'{cfg["wandb_name"]}-{idx}'] = cfg
wandb.init(entity="amudide", project="TopK (Frequent Log)", config=run_config)

trainSAE(
    buffer,
    trainer_configs=trainer_configs,
    save_dir="dictionaries",
    log_steps=1,
    steps=steps,
)

print("Training finished. Evaluating SAE...", flush=True)
# Rebuild each TopK dictionary with its k, then log its evaluation metrics.
for idx, cfg in enumerate(trainer_configs):
    ckpt_path = f"dictionaries/{cfg_filename(cfg)}/ae.pt"
    ae = AutoEncoderTopK.from_pretrained(ckpt_path, k=cfg["k"], device=device)
    metrics = evaluate(ae, buffer, device=device)
    prefix = f'{cfg["wandb_name"]}-{idx}'
    wandb.log({f"{prefix}/{name}": value for name, value in metrics.items()}, step=steps + 1)
wandb.finish()
--------------------------------------------------------------------------------