├── .gitignore
├── LICENSE
├── README.md
├── analysis.py
├── rare_freq_dir.pt
├── scratch.py
├── train.py
└── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | wandb/*
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 Neel Nanda
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # TLDR
2 |
3 | This is an open-source replication of [Anthropic's Towards Monosemanticity](https://transformer-circuits.pub/2023/monosemantic-features/index.html) paper. The autoencoder was trained on the gelu-1l model in TransformerLens; you can access two trained autoencoders and the model using [this tutorial](https://colab.research.google.com/drive/1u8larhpxy8w4mMsJiSBddNOzFGj7_RTn#scrollTo=MYrIYDEfBtbL).
4 |
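To load the model yourself, a minimal sketch (the tutorial above shows how to load the trained autoencoders themselves):

```python
from transformer_lens import HookedTransformer

# The one-layer GELU model the autoencoders were trained on.
model = HookedTransformer.from_pretrained("gelu-1l")
```
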
5 | # Reading This Codebase
6 |
7 | This is a pretty scrappy training codebase and won't run end-to-end from the top. I mostly recommend reading the code and copying snippets. See also [Hoagy Cunningham's Github](https://github.com/HoagyC/sparse_coding).
8 |
9 | * `utils.py` contains various utils to define the Autoencoder, data Buffer and training data.
10 | * Toggle `loading_data_first_time` to True to load and process the text data used to run the model and generate acts
11 | * `train.py` is a scrappy training script
12 | * `cfg["remove_rare_dir"]` was an experiment in training an autoencoder whose features were all orthogonal to the shared direction among rare features; those lines of code can be ignored and weren't used for the open source autoencoders.
13 | * There was a bug in the code that keeps the decoder weights at unit norm - it makes the gradients orthogonal to the decoder directions, but I forgot to *also* set the norm back to 1 after each gradient update (it turns out a unit-norm vector plus a perpendicular vector does not remain unit norm!). I think I have now fixed the bug; see the sketch after this list.
14 | * `analysis.py` is a scrappy set of experiments for exploring the autoencoder. I recommend reading the Colab tutorial instead for something cleaner and better commented.
15 |
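As referenced in the bug note above, a minimal sketch of the missing renormalisation step (assumes an `encoder` object with a `W_dec` parameter as in this codebase; the helper name is hypothetical):

```python
import torch

@torch.no_grad()
def renormalise_decoder(encoder):
    # A unit-norm row plus an orthogonal gradient step is no longer unit norm,
    # so rescale each decoder row back to norm 1 after every optimizer step.
    encoder.W_dec.data /= encoder.W_dec.data.norm(dim=-1, keepdim=True)
```
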
16 | Setup Notes:
17 |
18 | * Create data - you'll need to set the flag `loading_data_first_time` to True in `utils.py`. Note that this downloads the training mix of gelu-1l; if using e.g. the Pythia models you'll need different data (I recommend https://huggingface.co/datasets/monology/pile-uncopyrighted ).
19 | * A bunch of folders are hardcoded as /workspace/...; change these for your system.
20 | * Create a checkpoints dir at /workspace/1L-Sparse-Autoencoder/checkpoints
21 |
22 | * If you train an autoencoder and want to share the weights: copy the final checkpoints to a new folder, create your own repo on HuggingFace, and use `upload_folder_to_hf` to upload them. You'll need to run `huggingface-cli login` to log in, and `apt-get install git-lfs` then `git lfs install` first.
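A minimal sketch of that upload step (the import path, folder, and repo name here are hypothetical; `upload_folder_to_hf` is defined in this codebase):

```python
from utils import upload_folder_to_hf  # assumed location of the helper

# Copy the final checkpoints into their own folder first, then:
upload_folder_to_hf("/workspace/1L-Sparse-Autoencoder/checkpoints_final", repo_name="my-sparse-autoencoder")
```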
--------------------------------------------------------------------------------
/analysis.py:
--------------------------------------------------------------------------------
1 | # %%
2 | new_cfg = {
3 | "layer": 1,
4 | }
5 | from utils import *
6 | torch.set_grad_enabled(False)
7 | cfg.update(new_cfg)
8 | post_init_cfg(cfg)
9 | # %%
10 | tokens = all_tokens[:256]
11 | _, cache = model.run_with_cache(tokens, stop_at_layer=cfg["layer"]+1, names_filter=cfg["act_name"])
12 | acts = cache[cfg["act_name"]]
13 | acts_flattened = acts.reshape(-1, cfg["act_size"])
14 | encoder = AutoEncoder.load("gelu-2l_L1_16384_mlp_out_50")
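# Manually apply the autoencoder: hidden = ReLU((x - b_dec) @ W_enc + b_enc), reconstruction = hidden @ W_dec + b_dec.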
15 | hidden_acts = F.relu((acts_flattened - encoder.b_dec) @ encoder.W_enc + encoder.b_enc)
16 | reconstr = hidden_acts @ encoder.W_dec + encoder.b_dec
17 | l2_loss = (acts_flattened - reconstr).pow(2).sum(-1).mean(0)
18 | l1_loss = encoder.l1_coeff * (hidden_acts.abs().sum())
19 | print(l2_loss, l1_loss)
20 | # %%
21 | freqs = get_freqs(25, encoder)
22 | # %%
23 | histogram((freqs+10**-6.5).log10(), histnorm="percent", title="Frequencies for Final Checkpoint", xaxis="Freq (Log10)", yaxis="Percent")
24 |
25 | # %%
26 | is_rare = freqs < 1e-4
27 |
28 | is_rare[0] = True
29 |
30 | rare_enc = encoder.W_enc[:, is_rare]
31 | rare_mean = rare_enc.mean(-1)
32 |
33 | cos_with_mean = (rare_mean @ rare_enc) / rare_mean.norm() / rare_enc.norm(dim=0)
34 | histogram(cos_with_mean, histnorm="percent", marginal="box", title="Cosine sim of rare features with mean rare direction", yaxis="percent", xaxis="Cosine Sim")
35 | proj_onto_mean = (rare_mean @ rare_enc) / rare_mean.norm()
36 | histogram(proj_onto_mean, histnorm="percent", marginal="box", title="Projection of rare features onto mean rare direction", yaxis="percent", xaxis="Projection")
37 |
38 | print((cos_with_mean > 0.95).float().mean())
39 |
40 | scatter(x=proj_onto_mean, y=(freqs[is_rare]+10**-6.5).log10())
41 | scatter(x=cos_with_mean, y=(freqs[is_rare]+10**-6.5).log10())
42 | # %%
43 |
44 | feature_df = pd.DataFrame()
45 | feature_kurtosis = scipy.stats.kurtosis(to_numpy(encoder.W_dec.T))
46 | feature_df["neuron_kurt"] = to_numpy(feature_kurtosis)
47 | feature_kurtosis_enc = scipy.stats.kurtosis(to_numpy(encoder.W_enc))
48 | feature_df["neuron_kurt_enc"] = to_numpy(feature_kurtosis_enc)
49 | feature_df["is_rare"] = to_numpy(is_rare)
50 | feature_df["freq"] = to_numpy(freqs)
51 | # %%
52 | # encoder2 = AutoEncoder.load(47)
53 | # freqs2 = get_freqs(25, encoder2)
54 | # is_rare2 = freqs2 < 1e-4
55 | # rare_enc2 = encoder2.W_enc[:, is_rare2]
56 | # rare_mean2 = rare_enc2.mean(-1)
57 | # cos_with_mean2 = (rare_mean2 @ rare_enc2) / rare_mean2.norm() / rare_enc2.norm(dim=0)
58 | # histogram(cos_with_mean2, histnorm="percent", marginal="box", title="Cosine sim of rare features with mean rare direction", yaxis="percent", xaxis="Cosine Sim")
59 | # # %%
60 | # rare_mean2 @ rare_mean / rare_mean2.norm() / rare_mean.norm()
61 | # # %%
62 | # def basic_feature_vis(text, feature_index, max_val=0):
63 | # feature_in = encoder.W_enc[:, feature_index]
64 | # feature_bias = encoder.b_enc[feature_index]
65 | # _, cache = model.run_with_cache(text, stop_at_layer=cfg["layer"]+1, names_filter=cfg["act_name"])
66 | # acts = cache[cfg["act_name"]][0]
67 | # feature_acts = F.relu((acts - encoder.b_dec) @ feature_in + feature_bias)
68 | # if max_val==0:
69 | # max_val = max(1e-7, feature_acts.max().item())
70 | # # print(max_val)
71 | # # if min_val==0:
72 | # # min_val = min(-1e-7, feature_acts.min().item())
73 | # return basic_token_vis_make_str(text, feature_acts, max_val)
74 | # def basic_token_vis_make_str(strings, values, max_val=None):
75 | # if not isinstance(strings, list):
76 | # strings = model.to_str_tokens(strings)
77 | # values = to_numpy(values)
78 | # if max_val is None:
79 | # max_val = values.max()
80 | # # if min_val is None:
81 | # # min_val = values.min()
82 | #     header_string = f"Max Range {values.max():.4f} Min Range: {values.min():.4f}<br>"
83 | #     header_string += f"Set Max Range {max_val:.4f}<br>"
84 | #     # values[values>0] = values[values>0]/max_val
85 | # # values[values<0] = values[values<0]/abs(min_val)
86 | # body_string = nutils.create_html(strings, values, max_value=max_val, return_string=True)
87 | # return header_string + body_string
88 | # display(HTML(basic_token_vis_make_str(tokens[0, :10], acts[0, :10, 7], 0.1)))
89 | # display(HTML(basic_feature_vis("I really like things", 7)))
90 | # %%
91 | # The `with gr.Blocks() as demo:` syntax just creates a variable called demo containing all these components
92 | # import gradio as gr
93 | # try:
94 | # demos[0].close()
95 | # except:
96 | # pass
97 | # demos = [None]
98 | # def make_feature_vis_gradio(batch, pos, feature_id):
99 | # try:
100 | # demos[0].close()
101 | # except:
102 | # pass
103 | # with gr.Blocks() as demo:
104 | # gr.HTML(value=f"Hacky Interactive Neuroscope for gelu-1l")
105 | # # The input elements
106 | # with gr.Row():
107 | # with gr.Column():
108 | # text = gr.Textbox(label="Text", value=model.to_string(tokens[batch, 1:pos+1]))
109 | # # Precision=0 makes it an int, otherwise it's a float
110 | # # Value sets the initial default value
111 | # feature_index = gr.Number(
112 | # label="Feature Index", value=feature_id, precision=0
113 | # )
114 | # # # If empty, these two map to None
115 | # max_val = gr.Number(label="Max Value", value=None)
116 | # # min_val = gr.Number(label="Min Value", value=None)
117 | # inputs = [text, feature_index, max_val]
118 | # with gr.Row():
119 | # with gr.Column():
120 | # # The output element
121 | # out = gr.HTML(label="Neuron Acts", value=basic_feature_vis(model.to_string(tokens[batch, 1:pos+1]), feature_id))
122 | # for inp in inputs:
123 | # inp.change(basic_feature_vis, inputs, out)
124 | # demo.launch(share=True)
125 | # demos[0] = demo
126 | # %%
127 |
128 | # values, indices = (hidden_acts[:, is_rare]>0).float().mean(-1).sort()
129 | # # histogram((acts_flattened @ rare_mean))
130 | # token_df = nutils.make_token_df(tokens)
131 | # token_df["rare_proj"] = to_numpy(acts_flattened @ rare_mean)
132 | # token_df["frac_rare_active"] = to_numpy((hidden_acts[:, is_rare]>0).float().mean(-1))
133 | # sorted_token_df = token_df.sort_values("rare_proj", ascending=False)
134 | # for i in range(3):
135 | # b = sorted_token_df.batch.iloc[i]
136 | # p = sorted_token_df.pos.iloc[i]
137 | # print(f"Frac rare features active on final token: {sorted_token_df.frac_rare_active.iloc[i]:.2%}")
138 | # curr_tokens = tokens[b, :p+1]
139 | # values = acts[b, :p+1] @ rare_mean
140 | # nutils.create_html(model.to_str_tokens(curr_tokens), values)
141 |
142 | # %%
143 | # # %%
144 | # sorted_W_enc = encoder.W_enc.abs().sort(dim=0).values
145 | # sorted_W_enc_sq = sorted_W_enc.pow(2)
146 | # sorted_W_enc_sq_sum = sorted_W_enc.pow(2).sum(0)
147 | # for k in [1, 2, 5, 10, 100]:
148 | # feature_df[f"fve_top_{k}"] = to_numpy(sorted_W_enc_sq[-k:, :].sum(0) / sorted_W_enc_sq_sum)
149 | # feature_df["fve_next_9"] = feature_df["fve_top_10"] - feature_df["fve_top_1"]
150 | # feature_df.sort_values("neuron_kurt", ascending=False)
151 | # # %%
152 | # px.histogram(feature_df, x="fve_top_1", histnorm="percent", cumulative=True, marginal="box").show()
153 | # px.histogram(feature_df, x="fve_top_5", histnorm="percent", cumulative=True, marginal="box").show()
154 | # px.histogram(feature_df, x="fve_top_10", histnorm="percent", cumulative=True, marginal="box").show()
155 | # px.histogram(feature_df, x="fve_top_100", marginal="box").show()
156 | # # %%
157 | # px.scatter(feature_df.query("~is_rare"), x="fve_top_1", y="fve_next_9", hover_name=feature_df.query("~is_rare").index, color_continuous_scale="Portland", marginal_x="histogram", marginal_y="histogram", title="Fraction of Squared Sum Explained by Top Neuron vs Next 9 Neurons", opacity=0.2, labels={"fve_top_1":"Frac Explained by Top Neuron", "fve_next_9":"Frac Explained by Next 9 Neurons"})
158 | # # %%
159 | # sorted_W_dec = encoder.W_dec.T.abs().sort(dim=0).values
160 | # sorted_W_dec_sq = sorted_W_dec.pow(2)
161 | # sorted_W_dec_sq_sum = sorted_W_dec.pow(2).sum(0)
162 | # for k in [1, 2, 5, 10, 100]:
163 | # feature_df[f"fve_top_{k}_dec"] = to_numpy(sorted_W_dec_sq[-k:, :].sum(0) / sorted_W_dec_sq_sum)
164 | # feature_df["fve_next_9_dec"] = feature_df["fve_top_10_dec"] - feature_df["fve_top_1_dec"]
165 | # px.scatter(feature_df.query("~is_rare"), x='fve_top_1', y="fve_top_1_dec").show()
166 | # # %%
167 | # px.scatter(feature_df.query("~is_rare"), x="fve_top_1_dec", y="fve_next_9_dec", hover_name=feature_df.query("~is_rare").index, color_continuous_scale="Portland", marginal_x="histogram", marginal_y="histogram", title="Fraction of Squared Sum Explained by Top Neuron vs Next 9 Neurons", opacity=0.2, labels={"fve_top_1_dec":"Frac Explained by Top Neuron", "fve_next_9_dec":"Frac Explained by Next 9 Neurons"})
168 | # # %%
169 | # feature_df["1-sparse"] = (feature_df["fve_top_1_dec"]>0.35) & (feature_df["fve_next_9_dec"]<0.1)
170 | # feature_df["10-sparse"] = (feature_df["fve_top_10_dec"]>0.35) & (~feature_df["1-sparse"])
171 | # print(f"Frac 1 sparse: {feature_df['1-sparse'].mean():.3f}")
172 | # print(f"Frac 10 sparse: {feature_df['10-sparse'].mean():.3f}")
173 | # def f(row):
174 | # if row["1-sparse"]:
175 | # return "1-sparse"
176 | # elif row["10-sparse"]:
177 | # return "10-sparse"
178 | # else:
179 | # return "dense"
180 | # feature_df["sparsity_label"] = feature_df.apply(f, axis=1)
181 | # px.scatter(feature_df.query("~is_rare"), x="fve_top_1_dec", y="fve_next_9_dec", hover_name=feature_df.query("~is_rare").index, color="sparsity_label", color_continuous_scale="Portland", marginal_x="histogram", marginal_y="histogram", title="Fraction of Squared Decoder Sum Explained by Top Neuron vs Next 9 Neurons", opacity=0.2, labels={"fve_top_1_dec":"Frac Explained by Top Neuron", "fve_next_9_dec":"Frac Explained by Next 9 Neurons"})
182 | # # %%
183 | # px.histogram(feature_df.query("~is_rare"), x="fve_top_1", log_y=True)
184 | # # %%
185 | # feature_df_baseline = pd.DataFrame()
186 | # rand_W_dec = torch.randn_like(encoder.W_dec.T)
187 | # feature_kurtosis = scipy.stats.kurtosis(to_numpy(rand_W_dec))
188 | # feature_df_baseline["neuron_kurt"] = to_numpy(feature_kurtosis)
189 | # feature_df_baseline["is_rare"] = to_numpy(is_rare)
190 | # # %%
191 | # sorted_W_dec = rand_W_dec.abs().sort(dim=0).values
192 | # sorted_W_dec_sq = sorted_W_dec.pow(2)
193 | # sorted_W_dec_sq_sum = sorted_W_dec.pow(2).sum(0)
194 | # for k in [1, 2, 5, 10, 100]:
195 | # feature_df_baseline[f"fve_top_{k}"] = to_numpy(sorted_W_dec_sq[-k:, :].sum(0) / sorted_W_dec_sq_sum)
196 | # feature_df_baseline["fve_next_9"] = feature_df_baseline["fve_top_10"] - feature_df_baseline["fve_top_1"]
197 | # feature_df_baseline.sort_values("neuron_kurt", ascending=False)
198 | # # %%
199 | # px.histogram(feature_df_baseline, x="fve_top_1", histnorm="percent", cumulative=True, marginal="box").show()
200 | # px.histogram(feature_df_baseline, x="fve_top_5", histnorm="percent", cumulative=True, marginal="box").show()
201 | # px.histogram(feature_df_baseline, x="fve_top_10", histnorm="percent", cumulative=True, marginal="box").show()
202 | # px.histogram(feature_df_baseline, x="fve_top_100", marginal="box").show()
203 | # # %%
204 | # px.scatter(feature_df_baseline.query("~is_rare"), x="fve_top_1", y="fve_next_9", hover_name=feature_df_baseline.query("~is_rare").index, color_continuous_scale="Portland", marginal_x="histogram", marginal_y="histogram", title="")
205 | # # %%
206 | # temp_df = pd.concat([feature_df_baseline.query("~is_rare"), feature_df.query("~is_rare")])
207 | # temp_df = temp_df.reset_index(drop=True)
208 | # temp_df["category"] = ["baseline"]*(len(temp_df)//2) + ["real"]*(len(temp_df)//2)
209 | # px.histogram(temp_df, "neuron_kurt", color="category", barmode="overlay", marginal="box", range_x=(-5, 50), nbins=5000, title="Neuron Kurtosis (real vs random baseline, clipped at 50)")
210 |
211 | # # %%
212 | # feature_df["enc_dec_sim"] = to_numpy((encoder.W_dec * encoder.W_enc.T).sum(-1) / encoder.W_enc.norm(dim=0) / encoder.W_dec.norm(dim=-1))
213 | # px.histogram(feature_df.query("~is_rare"), x="enc_dec_sim", title="Encoder Decoder Cosine Sim (non-rare features)")
214 | # # %%
215 | # U, S, Vh = torch.linalg.svd((model.W_out[0]))
216 | # print(U.shape)
217 | # line(S)
218 | # # %%
219 | # W_enc_svd = encoder.W_enc.T @ U
220 | # W_enc_svd_null = W_enc_svd[:, 512:].pow(2).sum(-1)
221 | # W_enc_svd_all = W_enc_svd[:, :].pow(2).sum(-1)
222 | # feature_df["enc_null_frac"] = to_numpy(W_enc_svd_null / W_enc_svd_all)
223 |
224 | # W_dec_svd = encoder.W_dec @ U
225 | # W_dec_svd_null = W_dec_svd[:, 512:].pow(2).sum(-1)
226 | # W_dec_svd_all = W_dec_svd[:, :].pow(2).sum(-1)
227 | # feature_df["dec_null_frac"] = to_numpy(W_dec_svd_null / W_dec_svd_all)
228 | # # px.histogram(feature_df, x="enc_null_frac", color="is_rare", barmode="overlay").show()
229 | # # px.histogram(feature_df, x="dec_null_frac", color="is_rare", barmode="overlay").show()
230 | # # px.scatter(feature_df, x="dec_null_frac", y="enc_null_frac", color="is_rare").show()
231 | # fig = px.histogram(feature_df.query("~is_rare"), x=["enc_null_frac", "dec_null_frac"], barmode="overlay", marginal="box", title="Fraction of feature in W_out null space (non-rare)")
232 | # fig.add_vline(x=0.75, line_dash="dash", line_color="gray")
233 | # # %%
234 | # fig = px.histogram(feature_df.query("~is_rare"), x=["enc_null_frac", "enc_null_frac_baseline"], barmode="overlay", marginal="box", title="Fraction of feature in W_out null space (non-rare)")
235 | # fig.add_vline(x=0.75, line_dash="dash", line_color="gray").show()
236 | # fig = px.histogram(feature_df.query("~is_rare"), x=["dec_null_frac", "dec_null_frac_baseline"], barmode="overlay", marginal="box", title="Fraction of feature in W_out null space (non-rare)")
237 | # fig.add_vline(x=0.75, line_dash="dash", line_color="gray")
238 |
239 | # %%
240 | # print(feature_df["enc_null_frac_baseline"].mean())
241 | # print(feature_df["enc_null_frac_baseline"].std())
242 | # print(feature_df["dec_null_frac_baseline"].mean())
243 | # print(feature_df["dec_null_frac_baseline"].std())
244 |
245 | # %%
246 | def basic_feature_vis(text, feature_index, max_val=0):
247 | feature_in = encoder.W_enc[:, feature_index]
248 | feature_bias = encoder.b_enc[feature_index]
249 | _, cache = model.run_with_cache(text, stop_at_layer=cfg["layer"]+1, names_filter=cfg["act_name"])
250 | acts = cache[cfg["act_name"]][0]
251 | feature_acts = F.relu((acts - encoder.b_dec) @ feature_in + feature_bias)
252 | if max_val==0:
253 | max_val = max(1e-7, feature_acts.max().item())
254 | # print(max_val)
255 | # if min_val==0:
256 | # min_val = min(-1e-7, feature_acts.min().item())
257 | return basic_token_vis_make_str(text, feature_acts, max_val)
258 | def basic_token_vis_make_str(strings, values, max_val=None):
259 | if not isinstance(strings, list):
260 | strings = model.to_str_tokens(strings)
261 | values = to_numpy(values)
262 | if max_val is None:
263 | max_val = values.max()
264 | # if min_val is None:
265 | # min_val = values.min()
266 |     header_string = f"Max Range {values.max():.4f} Min Range: {values.min():.4f}<br>"
267 |     header_string += f"Set Max Range {max_val:.4f}<br>"
268 |     # values[values>0] = values[values>0]/max_val
269 | # values[values<0] = values[values<0]/abs(min_val)
270 | body_string = nutils.create_html(strings, values, max_value=max_val, return_string=True)
271 | return header_string + body_string
272 |
273 | # The `with gr.Blocks() as demo:` syntax just creates a variable called demo containing all these components
274 | import gradio as gr
275 | try:
276 | demos[0].close()
277 | except:
278 | pass
279 | demos = [None]
280 | def make_feature_vis_gradio(batch, pos, feature_id):
281 | try:
282 | demos[0].close()
283 | except:
284 | pass
285 | with gr.Blocks() as demo:
286 | gr.HTML(value=f"Hacky Interactive Neuroscope for gelu-1l")
287 | # The input elements
288 | with gr.Row():
289 | with gr.Column():
290 | text = gr.Textbox(label="Text", value=model.to_string(tokens[batch, 1:pos+1]))
291 | # Precision=0 makes it an int, otherwise it's a float
292 | # Value sets the initial default value
293 | feature_index = gr.Number(
294 | label="Feature Index", value=feature_id, precision=0
295 | )
296 | # # If empty, these two map to None
297 | max_val = gr.Number(label="Max Value", value=None)
298 | # min_val = gr.Number(label="Min Value", value=None)
299 | inputs = [text, feature_index, max_val]
300 | with gr.Row():
301 | with gr.Column():
302 | # The output element
303 | out = gr.HTML(label="Neuron Acts", value=basic_feature_vis(model.to_string(tokens[batch, 1:pos+1]), feature_id))
304 | for inp in inputs:
305 | inp.change(basic_feature_vis, inputs, out)
306 | demo.launch(share=True)
307 | demos[0] = demo
308 | # %%
309 | vocab_df = pd.DataFrame({
310 | "token": np.arange(d_vocab),
311 | "str_token": nutils.process_tokens(np.arange(d_vocab)),
312 | })
313 | vocab_df["is_upper"] = vocab_df["str_token"].apply(lambda s: s!=s.lower() and s==s.upper())
314 | vocab_df["is_word"] = vocab_df["str_token"].apply(lambda s: s.replace(nutils.SPACE, "").isalpha())
315 | vocab_df["has_space"] = vocab_df["str_token"].apply(lambda s: len(s)>0 and s[0]==nutils.SPACE)
316 | vocab_df["is_capital"] = vocab_df["str_token"].apply(lambda s: len(s)>0 and ((s[0]==nutils.SPACE and s[1:]==s[1:].capitalize()) or (s[0]!=nutils.SPACE and s==s.capitalize())))
317 | vocab_df
318 | # %%
319 | torch.set_grad_enabled(False)
320 | feature_U = (encoder.W_dec) @ model.W_U
321 | vocab_kurts = scipy.stats.kurtosis(to_numpy(feature_U.T))
322 | feature_df["vocab_kurt"] = vocab_kurts
323 | # %%
324 | tokens = all_tokens[:1024]
325 | _, cache = model.run_with_cache(tokens, stop_at_layer=cfg["layer"]+1, names_filter=cfg["act_name"])
326 | del _
327 | acts = cache[cfg["act_name"]]
328 | acts_flattened = acts.reshape(-1, cfg["act_size"])
329 | hidden_acts = F.relu((acts_flattened - encoder.b_dec) @ encoder.W_enc + encoder.b_enc)
330 | # %%
331 | vocab_has_space = torch.tensor(vocab_df.has_space.values).cuda()
332 | token_has_space = vocab_has_space[tokens.flatten()]
333 | # %%
334 | f_id = 2
335 | print(feature_df.loc[f_id])
336 |
337 | token_df = nutils.make_token_df(tokens, 8)
338 | token_df["act"] = to_numpy(hidden_acts[:, f_id])
339 | token_df["active"] = to_numpy(hidden_acts[:, f_id]>0)
340 | token_df = token_df.sort_values("act", ascending=False)
341 | nutils.show_df(token_df.head(50))
342 |
343 | nutils.show_df(nutils.create_vocab_df(feature_U[f_id]).head(20))
344 | nutils.show_df(nutils.create_vocab_df(feature_U[f_id]).tail(10))
345 | px.histogram(to_numpy(feature_U[f_id]), color=vocab_df.has_space, barmode="overlay", histnorm="percent", marginal="box", labels={"title": "wDLA with(out) space", "color":"has_space"}, title="wDLA with(out) space").show()
346 | print("Frac fires with space", token_has_space[hidden_acts[:, f_id]>0].float().mean().item())
347 | px.histogram(to_numpy(hidden_acts[:, f_id][hidden_acts[:, f_id]>0]), color=to_numpy(token_has_space[hidden_acts[:, f_id]>0]), barmode="overlay", marginal="box", title="Acts with(out) space").show()
348 | i = 0
349 | make_feature_vis_gradio(token_df.batch.iloc[i], token_df.pos.iloc[i], f_id)
350 | # %%
351 | # Feature 1
352 |
353 | ends_with_ed = vocab_df.str_token.apply(lambda s: s.endswith("ed")).values
354 | px.histogram(to_numpy(feature_U[f_id]), color=ends_with_ed, barmode="overlay", histnorm="percent", marginal="box", labels={"title": "wDLA with(out) space", "color":"has_ed"}, title="wDLA with(out) ed", hover_name=vocab_df.str_token).show()
355 | # %%
356 | # Feature 2
357 |
358 | # %%
359 |
360 |
361 | # hidden = hidden_acts[:, f_id].reshape(tokens.shape)
362 | # ave_firing = (hidden>0).float().mean(-1)
363 | # ave_act = (hidden).mean(-1)
364 | # big_fire_thresh = 0.2 * token_df.act.max()
365 | # ave_act_cond = (hidden).sum(-1) / ((hidden>0).float().sum(-1)+1e-7)
366 | # line([ave_firing, ave_act, ave_act_cond], line_labels=["Freq firing", "Ave act", "Ave act if firing"], title="Per batch summary statistics")
367 |
368 | # argmax_token = tokens.flatten()[hidden.flatten().argmax(-1).cpu()]
369 | # argmax_str_token = model.to_string(argmax_token)
370 | # print(argmax_token, argmax_str_token)
371 | # pos_token_df = token_df[token_df.act>0]
372 | # frac_of_fires_are_top_token = (pos_token_df.str_tokens==argmax_str_token).sum()/len(pos_token_df)
373 | # frac_big_firing_on_top_token = (pos_token_df.query(f"act>{big_fire_thresh}").str_tokens==argmax_str_token).sum()/len(pos_token_df.query(f"act>{big_fire_thresh}"))
374 | # frac_of_top_token_are_fires = (hidden.flatten().cpu()[tokens.flatten()==argmax_token]>0).float().mean().item()
375 | # print(f"{frac_of_fires_are_top_token=:.2%}")
376 | # print(f"{frac_big_firing_on_top_token=:.2%}")
377 | # print(f"{frac_of_top_token_are_fires=:.2%}")
378 | # print(f"Sample size = {(tokens.flatten()==argmax_token).sum().item()}")
379 |
380 | # line([encoder.W_enc[:, f_id], encoder.W_dec[f_id, :]], xaxis="Neuron", title="Weights in the neuron basis", line_labels=["encoder", "decoder"])
381 |
382 | # %%
383 |
384 | # %%
385 | temp_df = copy.deepcopy(vocab_df)
386 | def f(row):
387 | # print(row)
388 | if row.is_capital and row.has_space:
389 | return "Capital"
390 | elif not row.has_space and row.is_word:
391 | return "Fragment"
392 | else:
393 | return "Other"
394 | temp_df["cond"] = temp_df.apply(f, axis=1)
395 | temp_df["x"] = to_numpy(feature_U[f_id])
396 | px.histogram(temp_df, x="x", color="is_capital", barmode="overlay", marginal="box", histnorm="percent", hover_name="str_token").show()
397 | px.histogram(temp_df, x="x", color="cond", barmode="overlay", marginal="box", histnorm="percent", hover_name="str_token").show()
398 | # %%
399 |
--------------------------------------------------------------------------------
/rare_freq_dir.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/neelnanda-io/1L-Sparse-Autoencoder/bcae01328a2f41d24bd4a9160828f2fc22737f75/rare_freq_dir.pt
--------------------------------------------------------------------------------
/scratch.py:
--------------------------------------------------------------------------------
1 | # %%
2 | import os
3 | os.environ["TRANSFORMERS_CACHE"] = "/workspace/cache/"
4 | os.environ["DATASETS_CACHE"] = "/workspace/cache/"
5 | # %%
6 | from neel.imports import *
7 | from neel_plotly import *
8 | import wandb
9 | # %%
10 | import argparse
11 | def arg_parse_update_cfg(default_cfg):
12 | """
13 | Helper function that converts a dictionary of default arguments into command-line flags, parses what was passed in, and returns the updated dictionary.
14 |
15 | If running in IPython, returns the defaults unchanged.
16 | """
17 | if get_ipython() is not None:
18 | # Is in IPython
19 | print("In IPython - skipped argparse")
20 | return default_cfg
21 | cfg = dict(default_cfg)
22 | parser = argparse.ArgumentParser()
23 | for key, value in default_cfg.items():
24 | if type(value) == bool:
25 | # argparse handles booleans badly, so each boolean becomes a flag that flips its default: passing --{key} sets it to False if the default is True, and to True if the default is False
26 | if value:
27 | parser.add_argument(f"--{key}", action="store_false")
28 | else:
29 | parser.add_argument(f"--{key}", action="store_true")
30 |
31 | else:
32 | parser.add_argument(f"--{key}", type=type(value), default=value)
33 | args = parser.parse_args()
34 | parsed_args = vars(args)
35 | cfg.update(parsed_args)
36 | print("Updated config")
37 | print(json.dumps(cfg, indent=2))
38 | return cfg
39 | default_cfg = {
40 | "seed": 49,
41 | "batch_size": 4096,
42 | "buffer_mult": 384,
43 | "lr": 1e-4,
44 | "num_tokens": int(2e9),
45 | "l1_coeff": 3e-4,
46 | "beta1": 0.9,
47 | "beta2": 0.99,
48 | "dict_mult": 8,
49 | "seq_len": 128,
50 | "d_mlp": 2048,
51 | "enc_dtype":"fp32",
52 | "remove_rare_dir": True
53 | }
54 | cfg = arg_parse_update_cfg(default_cfg)
55 | cfg["model_batch_size"] = cfg["batch_size"] // cfg["seq_len"] * 16
56 | cfg["buffer_size"] = cfg["batch_size"] * cfg["buffer_mult"]
57 | cfg["buffer_batches"] = cfg["buffer_size"] // cfg["seq_len"]
58 | pprint.pprint(cfg)
59 | # %%
60 |
61 | SEED = cfg["seed"]
62 | GENERATOR = torch.manual_seed(SEED)
63 | DTYPES = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}
64 | np.random.seed(SEED)
65 | random.seed(SEED)
66 | torch.set_grad_enabled(True)
67 |
68 | model = HookedTransformer.from_pretrained("gelu-1l").to(DTYPES[cfg["enc_dtype"]])
69 |
70 | n_layers = model.cfg.n_layers
71 | d_model = model.cfg.d_model
72 | n_heads = model.cfg.n_heads
73 | d_head = model.cfg.d_head
74 | d_mlp = model.cfg.d_mlp
75 | d_vocab = model.cfg.d_vocab
76 | # %%
77 | @torch.no_grad()
78 | def get_mlp_acts(tokens, batch_size=1024):
79 | _, cache = model.run_with_cache(tokens, stop_at_layer=1, names_filter=utils.get_act_name("post", 0))
80 | mlp_acts = cache[utils.get_act_name("post", 0)]
81 | mlp_acts = mlp_acts.reshape(-1, d_mlp)
82 | subsample = torch.randperm(mlp_acts.shape[0], generator=GENERATOR)[:batch_size]
83 | subsampled_mlp_acts = mlp_acts[subsample, :]
84 | return subsampled_mlp_acts, mlp_acts
85 | # sub, acts = get_mlp_acts(torch.arange(20).reshape(2, 10), batch_size=3)
86 | # sub.shape, acts.shape
87 | # %%
88 | SAVE_DIR = Path("/workspace/1L-Sparse-Autoencoder/checkpoints")
89 | class AutoEncoder(nn.Module):
90 | def __init__(self, cfg):
91 | super().__init__()
92 | d_hidden = cfg["d_mlp"] * cfg["dict_mult"]
93 | l1_coeff = cfg["l1_coeff"]
94 | dtype = DTYPES[cfg["enc_dtype"]]
95 | torch.manual_seed(cfg["seed"])
96 |         self.W_enc = nn.Parameter(torch.nn.init.kaiming_uniform_(torch.empty(cfg["d_mlp"], d_hidden, dtype=dtype)))
97 |         self.W_dec = nn.Parameter(torch.nn.init.kaiming_uniform_(torch.empty(d_hidden, cfg["d_mlp"], dtype=dtype)))
98 |         self.b_enc = nn.Parameter(torch.zeros(d_hidden, dtype=dtype))
99 |         self.b_dec = nn.Parameter(torch.zeros(cfg["d_mlp"], dtype=dtype))  # use cfg rather than the module-level global d_mlp
100 |
101 | self.W_dec.data[:] = self.W_dec / self.W_dec.norm(dim=-1, keepdim=True)
102 |
103 | self.d_hidden = d_hidden
104 | self.l1_coeff = l1_coeff
105 |
106 | self.to("cuda")
107 |
108 | def forward(self, x):
109 | x_cent = x - self.b_dec
110 | acts = F.relu(x_cent @ self.W_enc + self.b_enc)
111 | x_reconstruct = acts @ self.W_dec + self.b_dec
112 | l2_loss = (x_reconstruct.float() - x.float()).pow(2).sum(-1).mean(0)
113 | l1_loss = self.l1_coeff * (acts.float().abs().sum())
114 | loss = l2_loss + l1_loss
115 | return loss, x_reconstruct, acts, l2_loss, l1_loss
116 |
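    # Keeps gradient updates orthogonal to each (unit-norm) decoder row. As per the
    # README bug note, an orthogonal step still changes the row's norm, so W_dec
    # must also be renormalised after each optimizer step.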
117 | @torch.no_grad()
118 | def remove_parallel_component_of_grads(self):
119 | W_dec_normed = self.W_dec / self.W_dec.norm(dim=-1, keepdim=True)
120 | W_dec_grad_proj = (self.W_dec.grad * W_dec_normed).sum(-1, keepdim=True) * W_dec_normed
121 | self.W_dec.grad -= W_dec_grad_proj
122 |
123 | def get_version(self):
124 |         return 1 + max([int(file.name.split(".")[0]) for file in SAVE_DIR.iterdir() if file.suffix == ".pt"], default=-1)
125 |
126 | def save(self):
127 | version = self.get_version()
128 | torch.save(self.state_dict(), SAVE_DIR/(str(version)+".pt"))
129 | with open(SAVE_DIR/(str(version)+"_cfg.json"), "w") as f:
130 | json.dump(cfg, f)
131 | print("Saved as version", version)
132 |
133 | @classmethod
134 | def load(cls, version):
135 | cfg = (json.load(open(SAVE_DIR/(str(version)+"_cfg.json"), "r")))
136 | pprint.pprint(cfg)
137 | self = cls(cfg=cfg)
138 | self.load_state_dict(torch.load(SAVE_DIR/(str(version)+".pt")))
139 | return self
140 | # %%
141 |
142 |
143 |
144 | # %%
145 |
146 |
147 | # l1_coeff = 0.01
148 | # encoder = AutoEncoder(d_mlp*cfg["dict_mult"], l1_coeff=cfg['l1_coeff']).cuda()
149 | # loss, x_reconstruct, acts, l2_loss, l1_loss = encoder(sub)
150 | # print(loss, l2_loss, l1_loss)
151 |
152 | # loss.backward()
153 | # print(encoder.W_dec.grad.norm())
154 | # encoder.remove_parallel_component_of_grads()
155 | # print(encoder.W_dec.grad.norm())
156 | # print((encoder.W_dec.grad * encoder.W_dec).sum(-1))
157 | # %%
158 | # wandb.init(project="autoencoder", entity="neelnanda-io")
159 | # # %%
160 | # c4_urls = [f"https://huggingface.co/datasets/allenai/c4/resolve/main/en/c4-train.{i:0>5}-of-01024.json.gz" for i in range(901, 950)]
161 |
162 | # dataset = load_dataset("json", data_files=c4_urls, split="train")
163 |
164 | # dataset_name="c4"
165 | # dataset.save_to_disk(f"/workspace/data/{dataset_name}_text.hf")
166 | # # %%
167 | # print(dataset)
168 |
169 | # from transformer_lens.utils import tokenize_and_concatenate
170 |
171 | # tokenizer = model.tokenizer
172 |
173 | # tokens = tokenize_and_concatenate(dataset, tokenizer, streaming=False, num_proc=20, max_length=128)
174 | # tokens.save_to_disk(f"/workspace/data/{dataset_name}_tokens.hf")
175 | # %%
176 | def shuffle_data(all_tokens):
177 | print("Shuffled data")
178 | return all_tokens[torch.randperm(all_tokens.shape[0])]
179 |
180 | loading_data_first_time = False
181 | if loading_data_first_time:
182 | data = load_dataset("NeelNanda/c4-code-tokenized-2b", split="train")
183 | data.save_to_disk("/workspace/data/c4_code_tokenized_2b.hf")
184 | data.set_format(type="torch", columns=["tokens"])
185 | all_tokens = data["tokens"]
186 | all_tokens.shape
187 |
188 |
189 | all_tokens_reshaped = einops.rearrange(all_tokens, "batch (x seq_len) -> (batch x) seq_len", x=8, seq_len=128)
190 | all_tokens_reshaped[:, 0] = model.tokenizer.bos_token_id
191 | all_tokens_reshaped = all_tokens_reshaped[torch.randperm(all_tokens_reshaped.shape[0])]
192 | torch.save(all_tokens_reshaped, "/workspace/data/c4_code_2b_tokens_reshaped.pt")
193 | else:
194 | # data = datasets.load_from_disk("/workspace/data/c4_code_tokenized_2b.hf")
195 | all_tokens = torch.load("/workspace/data/c4_code_2b_tokens_reshaped.pt")
196 | all_tokens = shuffle_data(all_tokens)
197 |
198 | # %%
199 | class Buffer():
200 | def __init__(self, cfg):
201 | self.buffer = torch.zeros((cfg["buffer_size"], cfg["d_mlp"]), dtype=torch.bfloat16, requires_grad=False).cuda()
202 | self.cfg = cfg
203 | self.token_pointer = 0
204 | self.first = True
205 | self.refresh()
206 |
207 | @torch.no_grad()
208 | def refresh(self):
209 | self.pointer = 0
210 | with torch.autocast("cuda", torch.bfloat16):
211 | if self.first:
212 | num_batches = self.cfg["buffer_batches"]
213 | else:
214 | num_batches = self.cfg["buffer_batches"]//2
215 | self.first = False
216 | for _ in range(0, num_batches, self.cfg["model_batch_size"]):
217 | tokens = all_tokens[self.token_pointer:self.token_pointer+self.cfg["model_batch_size"]]
218 | _, cache = model.run_with_cache(tokens, stop_at_layer=1, names_filter=utils.get_act_name("post", 0))
219 | mlp_acts = cache[utils.get_act_name("post", 0)].reshape(-1, self.cfg["d_mlp"])
220 | # print(tokens.shape, mlp_acts.shape, self.pointer, self.token_pointer)
221 | self.buffer[self.pointer: self.pointer+mlp_acts.shape[0]] = mlp_acts
222 | self.pointer += mlp_acts.shape[0]
223 | self.token_pointer += self.cfg["model_batch_size"]
224 | # if self.token_pointer > all_tokens.shape[0] - self.cfg["model_batch_size"]:
225 | # self.token_pointer = 0
226 |
227 | self.pointer = 0
228 | self.buffer = self.buffer[torch.randperm(self.buffer.shape[0]).cuda()]
229 |
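    # next() serves slices from the front of the shuffled buffer; once roughly half
    # is consumed, refresh() replaces the first half with fresh activations and
    # reshuffles the whole buffer.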
230 | @torch.no_grad()
231 | def next(self):
232 | out = self.buffer[self.pointer:self.pointer+self.cfg["batch_size"]]
233 | self.pointer += self.cfg["batch_size"]
234 | if self.pointer > self.buffer.shape[0]//2 - self.cfg["batch_size"]:
235 | # print("Refreshing the buffer!")
236 | self.refresh()
237 | return out
238 |
239 | # buffer.refresh()
240 | # %%
241 |
242 | # %%
243 | def replacement_hook(mlp_post, hook, encoder):
244 | mlp_post_reconstr = encoder(mlp_post)[1]
245 | return mlp_post_reconstr
246 |
247 | def mean_ablate_hook(mlp_post, hook):
248 | mlp_post[:] = mlp_post.mean([0, 1])
249 | return mlp_post
250 |
251 | def zero_ablate_hook(mlp_post, hook):
252 | mlp_post[:] = 0.
253 | return mlp_post
254 |
255 | @torch.no_grad()
256 | def get_recons_loss(num_batches=5, local_encoder=None):
257 | if local_encoder is None:
258 | local_encoder = encoder
259 | loss_list = []
260 | for i in range(num_batches):
261 | tokens = all_tokens[torch.randperm(len(all_tokens))[:cfg["model_batch_size"]]]
262 | loss = model(tokens, return_type="loss")
263 | recons_loss = model.run_with_hooks(tokens, return_type="loss", fwd_hooks=[(utils.get_act_name("post", 0), partial(replacement_hook, encoder=local_encoder))])
264 | # mean_abl_loss = model.run_with_hooks(tokens, return_type="loss", fwd_hooks=[(utils.get_act_name("post", 0), mean_ablate_hook)])
265 | zero_abl_loss = model.run_with_hooks(tokens, return_type="loss", fwd_hooks=[(utils.get_act_name("post", 0), zero_ablate_hook)])
266 | loss_list.append((loss, recons_loss, zero_abl_loss))
267 | losses = torch.tensor(loss_list)
268 | loss, recons_loss, zero_abl_loss = losses.mean(0).tolist()
269 |
270 | print(loss, recons_loss, zero_abl_loss)
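    # Score = fraction of the loss gap between zero-ablation and the clean model that the reconstruction recovers.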
271 | score = ((zero_abl_loss - recons_loss)/(zero_abl_loss - loss))
272 | print(f"{score:.2%}")
273 | # print(f"{((zero_abl_loss - mean_abl_loss)/(zero_abl_loss - loss)).item():.2%}")
274 | return score, loss, recons_loss, zero_abl_loss
275 | # print(get_recons_loss())
276 |
277 | # %%
278 | # Frequency
279 | @torch.no_grad()
280 | def get_freqs(num_batches=25, local_encoder=None):
281 | if local_encoder is None:
282 | local_encoder = encoder
283 | act_freq_scores = torch.zeros(local_encoder.d_hidden, dtype=torch.float32).cuda()
284 | total = 0
285 | for i in tqdm.trange(num_batches):
286 | tokens = all_tokens[torch.randperm(len(all_tokens))[:cfg["model_batch_size"]]]
287 |
288 | _, cache = model.run_with_cache(tokens, stop_at_layer=1, names_filter=utils.get_act_name("post", 0))
289 | mlp_acts = cache[utils.get_act_name("post", 0)]
290 | mlp_acts = mlp_acts.reshape(-1, d_mlp)
291 |
292 | hidden = local_encoder(mlp_acts)[2]
293 |
294 | act_freq_scores += (hidden > 0).sum(0)
295 | total+=hidden.shape[0]
296 | act_freq_scores /= total
297 |     frac_dead = (act_freq_scores==0).float().mean()
298 |     print("Frac dead", frac_dead.item())
299 | return act_freq_scores
300 | # %%
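# Re-initialise ("resample") features that have effectively died, i.e. ones that essentially never fire.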
301 | @torch.no_grad()
302 | def re_init(indices, encoder):
303 | new_W_enc = (torch.nn.init.kaiming_uniform_(torch.zeros_like(encoder.W_enc)))
304 | new_W_dec = (torch.nn.init.kaiming_uniform_(torch.zeros_like(encoder.W_dec)))
305 | new_b_enc = (torch.zeros_like(encoder.b_enc))
306 | print(new_W_dec.shape, new_W_enc.shape, new_b_enc.shape)
307 | encoder.W_enc.data[:, indices] = new_W_enc[:, indices]
308 | encoder.W_dec.data[indices, :] = new_W_dec[indices, :]
309 | encoder.b_enc.data[indices] = new_b_enc[indices]
310 | # %%
311 | encoder = AutoEncoder(cfg)
312 | buffer = Buffer(cfg)
313 | evil_dir = torch.load("/workspace/1L-Sparse-Autoencoder/evil_dir.pt")
314 | evil_dir.requires_grad = False
315 | # %%
316 | try:
317 | wandb.init(project="autoencoder", entity="neelnanda-io")
318 | num_batches = cfg["num_tokens"] // cfg["batch_size"]
319 | # model_num_batches = cfg["model_batch_size"] * num_batches
320 | encoder_optim = torch.optim.Adam(encoder.parameters(), lr=cfg["lr"], betas=(cfg["beta1"], cfg["beta2"]))
321 | recons_scores = []
322 | act_freq_scores_list = []
323 | for i in tqdm.trange(num_batches):
324 |         i = i % all_tokens.shape[0]  # vestigial from an older data-loading loop; effectively a no-op here
325 | # tokens = all_tokens[i:i+cfg["model_batch_size"]]
326 | # acts = get_mlp_acts(tokens, batch_size=cfg["batch_size"])[0].detach()
327 | acts = buffer.next()
328 | loss, x_reconstruct, mid_acts, l2_loss, l1_loss = encoder(acts)
329 | loss.backward()
330 | encoder.remove_parallel_component_of_grads()
331 | if cfg["remove_rare_dir"]:
332 | with torch.no_grad():
333 | encoder.W_enc.grad -= (evil_dir @ encoder.W_enc.grad)[None, :] * evil_dir[:, None]
334 | encoder_optim.step()
335 | encoder_optim.zero_grad()
336 | if cfg["remove_rare_dir"]:
337 | with torch.no_grad():
338 | encoder.W_enc -= (evil_dir @ encoder.W_enc)[None, :] * evil_dir[:, None]
339 | loss_dict = {"loss": loss.item(), "l2_loss": l2_loss.item(), "l1_loss": l1_loss.item()}
340 | del loss, x_reconstruct, mid_acts, l2_loss, l1_loss, acts
341 | if (i) % 100 == 0:
342 | wandb.log(loss_dict)
343 | print(loss_dict)
344 | if (i) % 1000 == 0:
345 | x = (get_recons_loss())
346 | print("Reconstruction:", x)
347 | recons_scores.append(x[0])
348 | freqs = get_freqs(5)
349 | act_freq_scores_list.append(freqs)
350 | # histogram(freqs.log10(), marginal="box", histnorm="percent", title="Frequencies")
351 | wandb.log({
352 | "recons_score": x[0],
353 | "dead": (freqs==0).float().mean().item(),
354 | "below_1e-6": (freqs<1e-6).float().mean().item(),
355 | "below_1e-5": (freqs<1e-5).float().mean().item(),
356 | })
357 | if (i+1) % 30000 == 0:
358 | encoder.save()
359 | wandb.log({"reset_neurons": 0.0})
360 | freqs = get_freqs(50)
361 | to_be_reset = (freqs<10**(-5.5))
362 | print("Resetting neurons!", to_be_reset.sum())
363 | re_init(to_be_reset, encoder)
364 | finally:
365 | encoder.save()
366 |
367 | # %%
368 |
369 | # # %%
370 | # acts = get_mlp_acts(all_tokens[i:i+cfg["model_batch_size"]], batch_size=cfg["batch_size"])[0].detach()
371 | # loss, x_reconstruct, mid_acts, l2_loss, l1_loss = encoder(acts)
372 | # # %%
373 | # acts.shape, x_reconstruct.shape
374 | # acts.norm(dim=-1).mean(), x_reconstruct.norm(dim=-1).mean(), (acts - x_reconstruct).norm(dim=-1).mean()
375 | # # %%
376 | # line(encoder.W_enc[:, :20].T)
377 | # line(encoder.W_dec[:20])
378 | # # %%
379 | # line(mid_acts.mean(0))
380 | # # %%
381 | freqs = get_freqs(50)
382 | histogram(freqs.log10(), marginal="box", histnorm="percent", title="Frequencies")
383 | freqs_5 = get_freqs(5)
384 | scatter(x=freqs_5.log10(), y=freqs.log10())
385 | # %%
386 |
387 | # %%
388 | (freqs<10**(-5.5)).float().mean()
389 | # %%
390 | @torch.no_grad()
391 | def re_init(indices, encoder):
392 | new_W_enc = (torch.nn.init.kaiming_uniform_(torch.zeros_like(encoder.W_enc)))
393 | new_W_dec = (torch.nn.init.kaiming_uniform_(torch.zeros_like(encoder.W_dec)))
394 | new_b_enc = (torch.zeros_like(encoder.b_enc))
395 | print(new_W_dec.shape, new_W_enc.shape, new_b_enc.shape)
396 | encoder.W_enc.data[:, indices] = new_W_enc[:, indices]
397 | encoder.W_dec.data[indices, :] = new_W_dec[indices, :]
398 | encoder.b_enc.data[indices] = new_b_enc[indices]
399 | freqs = get_freqs(50)
400 | to_be_reset = (freqs<10**(-5.5))
401 | re_init(to_be_reset, encoder)
402 | # %%
403 | x = (get_recons_loss())
404 | print("Reconstruction:", x)
405 | # recons_scores.append(x[0])
406 | freqs = get_freqs(5)
407 | # act_freq_scores_list.append(freqs)
408 | histogram((freqs+10**(-6.5)).log10(), marginal="box", histnorm="percent", title="Frequencies")
409 |
410 | # %%
411 |
412 | # %%
413 | enc2 = AutoEncoder.load(5)
414 | tokens = all_tokens[:32]
415 | acts = get_mlp_acts(tokens, batch_size=1)[1].detach()
416 | # acts = buffer.next()
417 | loss, x_reconstruct, mid_acts, l2_loss, l1_loss = enc2(acts)
418 | print(loss, l2_loss, l1_loss)
419 | # %%
420 | freqs = (mid_acts>0).float().mean(0)
421 | feature_df = pd.DataFrame({"freqs": to_numpy(freqs), "log_freq":to_numpy((freqs).log10())})
422 | feature_df[feature_df["log_freq"]>-5]
423 | # %%
424 | f_id = 18
425 | token_df = nutils.make_token_df(tokens)
426 | token_df["act"] = to_numpy(mid_acts[:, f_id])
427 | token_df["active"] = to_numpy(mid_acts[:, f_id]>0)
428 | nutils.show_df(token_df.sort_values("act", ascending=False).head(100))
429 |
430 | # %%
431 | line(token_df.groupby("batch")["act"].mean())
432 | line(token_df.groupby("batch")["active"].mean())
433 | # %%
434 | act_freq_scores_list = []
435 | encoders = {}
436 | checkpoints = [25, 22, 21, 18, 15, 12, 9]
437 | for i in checkpoints:
438 | print(i)
439 | encoders[i] = AutoEncoder.load(i)
440 | freqs = get_freqs(20, encoders[i])
441 | act_freq_scores_list.append(freqs)
442 | histogram((freqs+10**(-6.5)).log10(), marginal="box", histnorm="percent", title=f"Frequencies for checkpoint {i}")
443 | # %%
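# Checkpoints were saved every 30000 steps of 4096-token batches (see the training loop above), and the
# numbering appears to start around checkpoint 8, hence (c - 8) * 30000 * 4096 tokens.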
444 | def num_tokens_per_checkpoint(c):
445 | if c==25:
446 | # return "2000M"
447 | return 2000
448 | else:
449 | return int(((c - 8) * 30000 * 4096)/1e6)
450 | line(x=list(range(26)), y=[num_tokens_per_checkpoint(c) for c in range(26)], title="Number of tokens per checkpoint")
451 | freqs = torch.stack(act_freq_scores_list).flatten()
452 | temp_df = pd.DataFrame({"freqs": to_numpy(freqs), "log_freq":to_numpy((freqs+10**(-6.5)).log10()),
453 | "checkpoint": [c for c in checkpoints for _ in range(encoders[25].d_hidden)],
454 | "million_tokens": [num_tokens_per_checkpoint(c) for c in checkpoints for _ in range(encoders[25].d_hidden)],
455 | })
456 | px.histogram(temp_df, color="million_tokens", x="log_freq", barmode="overlay", marginal="box", histnorm="percent", title="Frequencies for checkpoints")
457 | # %%
458 | scatter(x=temp_df.query("checkpoint==21").log_freq, y=temp_df.query("checkpoint==22").log_freq, marginal_x="box", marginal_y="box", title="Frequencies for checkpoints 21 and 22", include_diag=True, xaxis=f"{num_tokens_per_checkpoint(21)}M", yaxis=f"{num_tokens_per_checkpoint(22)}M")
459 | scatter(x=temp_df.query("checkpoint==21").log_freq, y=temp_df.query("checkpoint==25").log_freq, marginal_x="box", marginal_y="box", title="Frequencies for checkpoints 21 and 25", include_diag=True, xaxis=f"{num_tokens_per_checkpoint(21)}M", yaxis=f"{num_tokens_per_checkpoint(25)}M")
460 | scatter(x=temp_df.query("checkpoint==22").log_freq, y=temp_df.query("checkpoint==25").log_freq, marginal_x="box", marginal_y="box", title="Frequencies for checkpoints 22 and 25", include_diag=True, xaxis=f"{num_tokens_per_checkpoint(22)}M", yaxis=f"{num_tokens_per_checkpoint(25)}M")
461 | # %%
462 | torch.set_grad_enabled(False)
463 | # %%
464 | tokens = all_tokens[:256]
465 | _, cache = model.run_with_cache(tokens, stop_at_layer=1, names_filter=utils.get_act_name("post", 0))
466 | mlp_acts = cache[utils.get_act_name("post", 0)]
467 | mlp_acts_flattened = mlp_acts.reshape(-1, cfg["d_mlp"])
468 | encoder = AutoEncoder.load(25)
469 | hidden_acts = F.relu((mlp_acts_flattened - encoder.b_dec) @ encoder.W_enc + encoder.b_enc)
470 | mlp_reconstr = hidden_acts @ encoder.W_dec + encoder.b_dec
471 | l2_loss = (mlp_acts_flattened - mlp_reconstr).pow(2).sum(-1).mean(0)
472 | l1_loss = encoder.l1_coeff * (hidden_acts.abs().sum())
473 | print(l2_loss, l1_loss)
474 | # %%
475 | freqs = get_freqs(25, encoder)
476 | # %%
477 | histogram((freqs+10**-6.5).log10(), histnorm="percent", title="Frequencies for Final Checkpoint", xaxis="Freq (Log10)", yaxis="Percent")
478 |
479 | # %%
480 | is_rare = freqs < 1e-4
481 |
482 |
483 | # %%
484 |
485 |
486 | # %%
487 |
488 | def replace_mlp_post(mlp_post, hook, replacement):
489 | mlp_post[:] = replacement
490 | return mlp_post
491 | recons_loss = model.run_with_hooks(tokens, return_type="loss", fwd_hooks=[(utils.get_act_name("post", 0), partial(replace_mlp_post, replacement=mlp_reconstr.reshape(mlp_acts.shape)))])
492 | zero_abl_loss = model.run_with_hooks(tokens, return_type="loss", fwd_hooks=[(utils.get_act_name("post", 0), partial(replace_mlp_post, replacement=torch.zeros_like(mlp_acts)))])
493 | normal_loss = model(tokens, return_type="loss")
494 | mean_loss = model.run_with_hooks(tokens, return_type="loss", fwd_hooks=[(utils.get_act_name("post", 0), partial(replace_mlp_post, replacement=mlp_acts.mean(0, keepdim=True).mean(1, keepdim=True)))])
495 | print(f"{recons_loss.item()=}")
496 | print(f"{zero_abl_loss.item()=}")
497 | print(f"{normal_loss.item()=}")
498 | print(f"{mean_loss.item()=}")
499 | # %%
500 | new_losses = []
501 | min_freq_list = []
502 | for thresh in [-6.5, -5, -4.5, -4.4, -4.3, -4.2, -4.1, -4, -3, -2.5, -2, -1.5, -1, 0]:
503 | indices = freqs >= 10**thresh
504 | replacement = hidden_acts[:, indices] @ encoder.W_dec[indices, :] + encoder.b_dec
505 | new_loss = model.run_with_hooks(tokens, return_type="loss", fwd_hooks=[(utils.get_act_name("post", 0), partial(replace_mlp_post, replacement=replacement.reshape(mlp_acts.shape)))])
506 | new_losses.append(new_loss.item())
507 | min_freq_list.append(thresh)
508 | new_losses = np.array(new_losses)
509 | line(x=min_freq_list, y=new_losses, title="Loss vs minimum frequency")
510 | line(x=min_freq_list, y=(zero_abl_loss.item() - new_losses)/(zero_abl_loss.item() - normal_loss.item()), title="Scaled Loss vs minimum frequency final checkpoint", yaxis="% Loss Recovered", xaxis="Log freq floor")
511 | # %%
512 | # encoder2 = AutoEncoder.load(21)
513 | # hidden_acts = F.relu((mlp_acts_flattened - encoder2.b_dec) @ encoder2.W_enc + encoder2.b_enc)
514 | # mlp_reconstr = hidden_acts @ encoder2.W_dec + encoder2.b_dec
515 | # l2_loss = (mlp_acts_flattened - mlp_reconstr).pow(2).sum(-1).mean(0)
516 | # l1_loss = encoder2.l1_coeff * (hidden_acts.abs().sum())
517 | # print(l2_loss, l1_loss)
518 |
519 | # freqs = get_freqs(25, encoder2)
520 | # histogram((freqs+10**-6.5).log10(), barmode="overlay", marginal="box", histnorm="percent", title="Frequencies for checkpoint 21")
521 |
522 | # %%
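# NB: this cell assumes the commented-out cell above has been run, so that encoder2 and hidden_acts come from checkpoint 21.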
523 | new_losses2 = []
524 | min_freq_list2 = []
525 | for thresh in [-6.5, -6, -5.5, -5, -4, -3, -2.5, -2, -1.5, -1, 0]:
526 | indices = freqs >= 10**thresh
527 | replacement = hidden_acts[:, indices] @ encoder2.W_dec[indices, :] + encoder2.b_dec
528 | new_loss = model.run_with_hooks(tokens, return_type="loss", fwd_hooks=[(utils.get_act_name("post", 0), partial(replace_mlp_post, replacement=replacement.reshape(mlp_acts.shape)))])
529 | new_losses2.append(new_loss.item())
530 | min_freq_list2.append(thresh)
531 | new_losses2 = np.array(new_losses2)
532 | line(x=min_freq_list2, y=new_losses2, title="Loss vs minimum frequency")
533 | line(x=min_freq_list2, y=(zero_abl_loss.item() - new_losses2)/(zero_abl_loss.item() - normal_loss.item()), title="Scaled Loss vs minimum frequency checkpoint 21", yaxis="% Loss Recovered", xaxis="Log freq floor")
534 |
535 | # %%
536 | fig = line(x=min_freq_list, y=(zero_abl_loss.item() - new_losses)/(zero_abl_loss.item() - normal_loss.item()), title="Scaled Reconstructed Loss vs minimum frequency checkpoint", yaxis="% Loss Recovered", xaxis="Log freq floor", return_fig=True, line_labels=["Final checkpoint"])
537 | fig.add_trace(go.Scatter(x=min_freq_list2, y=(zero_abl_loss.item() - new_losses2)/(zero_abl_loss.item() - normal_loss.item()), name="Checkpoint 21"))
538 | # %%
539 | def basic_feature_vis(text, feature_index, max_val=0):
540 | feature_in = encoder.W_enc[:, feature_index]
541 | feature_bias = encoder.b_enc[feature_index]
542 | _, cache = model.run_with_cache(text, stop_at_layer=1, names_filter=utils.get_act_name("post", 0))
543 | mlp_acts = cache[utils.get_act_name("post", 0)][0]
544 | feature_acts = F.relu((mlp_acts - encoder.b_dec) @ feature_in + feature_bias)
545 | if max_val==0:
546 | max_val = max(1e-7, feature_acts.max().item())
547 | # print(max_val)
548 | # if min_val==0:
549 | # min_val = min(-1e-7, feature_acts.min().item())
550 | return basic_token_vis_make_str(text, feature_acts, max_val)
551 | def basic_token_vis_make_str(strings, values, max_val=None):
552 | if not isinstance(strings, list):
553 | strings = model.to_str_tokens(strings)
554 | values = to_numpy(values)
555 | if max_val is None:
556 | max_val = values.max()
557 | # if min_val is None:
558 | # min_val = values.min()
559 |     header_string = f"Max Range {values.max():.4f} Min Range: {values.min():.4f}<br>"
560 |     header_string += f"Set Max Range {max_val:.4f}<br>"
561 |     # values[values>0] = values[values>0]/max_val
562 | # values[values<0] = values[values<0]/abs(min_val)
563 | body_string = nutils.create_html(strings, values, max_value=max_val, return_string=True)
564 | return header_string + body_string
565 | display(HTML(basic_token_vis_make_str(tokens[0, :10], mlp_acts[0, :10, 7], 0.1)))
566 | display(HTML(basic_feature_vis("I really like food food calories burgers eating is great", 7)))
567 | # %%
568 | # The `with gr.Blocks() as demo:` syntax just creates a variable called demo containing all these components
569 | import gradio as gr
570 | try:
571 | demos[0].close()
572 | except:
573 | pass
574 | demos = [None]
575 | def make_feature_vis_gradio(batch, pos, feature_id):
576 | try:
577 | demos[0].close()
578 | except:
579 | pass
580 | with gr.Blocks() as demo:
581 | gr.HTML(value=f"Hacky Interactive Neuroscope for gelu-1l")
582 | # The input elements
583 | with gr.Row():
584 | with gr.Column():
585 | text = gr.Textbox(label="Text", value=model.to_string(tokens[batch, 1:pos+1]))
586 | # Precision=0 makes it an int, otherwise it's a float
587 | # Value sets the initial default value
588 | feature_index = gr.Number(
589 | label="Feature Index", value=feature_id, precision=0
590 | )
591 | # # If empty, these two map to None
592 | max_val = gr.Number(label="Max Value", value=None)
593 | # min_val = gr.Number(label="Min Value", value=None)
594 | inputs = [text, feature_index, max_val]
595 | with gr.Row():
596 | with gr.Column():
597 | # The output element
598 | out = gr.HTML(label="Neuron Acts", value=basic_feature_vis(model.to_string(tokens[batch, 1:pos+1]), feature_id))
599 | for inp in inputs:
600 | inp.change(basic_feature_vis, inputs, out)
601 | demo.launch(share=True)
602 | demos[0] = demo
603 | # %%
604 | batch = 0
605 | feature_id = 7
606 | pos = 28
607 | make_feature_vis_gradio(batch, pos, feature_id)
608 | # %%
609 |
610 | # %%
611 | px.scatter(x=to_numpy(encoder.b_enc), y=to_numpy((freqs+10**-5).log10()), trendline="ols", labels={"x":"b_encoder", "y":"log10 freq", "color":"Is Rare"}, color=to_numpy(freqs<10**(-3.5)), title="Encoder bias vs frequency", marginal_x="histogram", marginal_y="histogram").show()
612 | px.scatter(x=to_numpy(encoder.W_enc.norm(dim=0)), y=to_numpy((freqs+10**-5).log10()), trendline="ols", labels={"x":"W_encoder.norm", "y":"log10 freq", "color":"Is Rare"}, color=to_numpy(freqs<10**(-3.5)), title="Encoder norm vs frequency", marginal_x="histogram", marginal_y="histogram").show()
613 | px.scatter(x=to_numpy(encoder.W_dec.norm(dim=-1)), y=to_numpy((freqs+10**-5).log10()), trendline="ols", labels={"x":"W_decoder.norm", "y":"log10 freq", "color":"Is Rare"}, color=to_numpy(freqs<10**(-3.5)), title="Decoder norm vs frequency", marginal_x="histogram", marginal_y="histogram").show()
614 | px.scatter(x=to_numpy(encoder.b_enc * encoder.W_dec.norm(dim=-1)), y=to_numpy((freqs+10**-5).log10()), trendline="ols", labels={"x":"b_encoder * W_dec.norm", "y":"log10 freq", "color":"Is Rare"}, color=to_numpy(freqs<10**(-3.5)), title="Weighted encoder bias vs frequency", marginal_x="histogram", marginal_y="histogram").show()
615 | px.scatter(x=to_numpy(encoder.W_enc.norm(dim=0) * encoder.W_dec.norm(dim=-1)), y=to_numpy((freqs+10**-5).log10()), trendline="ols", labels={"x":"W_enc.norm * W_dec.norm", "y":"log10 freq", "color":"Is Rare"}, color=to_numpy(freqs<10**(-3.5)), title="Encoder norm products vs frequency", marginal_x="histogram", marginal_y="histogram").show()
616 | # %%
617 | import huggingface_hub
618 | from pathlib import Path
619 | def push_to_hub(local_dir):
620 | if isinstance(local_dir, huggingface_hub.Repository):
621 | local_dir = local_dir.local_dir
622 | os.system(f"git -C {local_dir} add .")
623 | os.system(f"git -C {local_dir} commit -m 'Auto Commit'")
624 | os.system(f"git -C {local_dir} push")
625 |
626 |
627 | # move_folder_to_hub("v235_4L512W_solu_wikipedia", "NeelNanda/SoLU_4L512W_Wiki_Finetune", just_final=False)
628 | # def move_folder_to_hub(model_name, repo_name=None, just_final=True, debug=False):
629 | # if repo_name is None:
630 | # repo_name = model_name
631 | # model_folder = CHECKPOINT_DIR / model_name
632 | # repo_folder = CHECKPOINT_DIR / (model_name + "_repo")
633 | # repo_url = huggingface_hub.create_repo(repo_name, exist_ok=True)
634 | # repo = huggingface_hub.Repository(repo_folder, repo_url)
635 |
636 | # for file in model_folder.iterdir():
637 | # if not just_final or "final" in file.name or "config" in file.name:
638 | # if debug:
639 | # print(file.name)
640 | # file.rename(repo_folder / file.name)
641 | # push_to_hub(repo.local_dir)
642 | def upload_folder_to_hf(folder_path, repo_name=None, debug=False):
643 | folder_path = Path(folder_path)
644 | if repo_name is None:
645 | repo_name = folder_path.name
646 | repo_folder = folder_path.parent / (folder_path.name + "_repo")
647 | repo_url = huggingface_hub.create_repo(repo_name, exist_ok=True)
648 | repo = huggingface_hub.Repository(repo_folder, repo_url)
649 |
650 | for file in folder_path.iterdir():
651 | if debug:
652 | print(file.name)
653 | file.rename(repo_folder / file.name)
654 | push_to_hub(repo.local_dir)
655 | upload_folder_to_hf("/workspace/1L-Sparse-Autoencoder/checkpoints_copy_2", "sparse_autoencoder", True)
656 | # %%
657 | freqs = (hidden_acts>0).float().mean(0)
658 | feature_df = pd.DataFrame({"freqs": to_numpy(freqs), "log_freq":to_numpy((freqs).log10())})
659 | feature_df["is_common"] = feature_df["log_freq"]>-3.5
660 | neuron_kurts = scipy.stats.kurtosis(to_numpy(encoder.W_enc))
661 | feature_U = (encoder.W_dec @ model.W_out[0]) @ model.W_U
662 | vocab_kurts = scipy.stats.kurtosis(to_numpy(feature_U.T))
663 | feature_df["vocab_kurt"] = vocab_kurts
664 | feature_df["neuron_kurt"] = neuron_kurts
665 | neuron_frac_max = encoder.W_enc.max(dim=0).values / encoder.W_enc.abs().sum(0)
666 | feature_df["neuron_frac_max"] = to_numpy(neuron_frac_max)
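# neuron_kurt: kurtosis of each feature's encoder weights across neurons (high values suggest
# the feature reads from just a few neurons). neuron_frac_max: the largest positive encoder
# weight as a fraction of the feature's total absolute encoder mass (near 1 suggests the
# feature is approximately a single neuron).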
667 | # %%
668 | encoder2 = AutoEncoder.load(47)
669 | freqs2 = get_freqs(5, encoder2)
670 | is_common2 = freqs2>10**-3.5
671 | is_common1 = freqs>10**-3.5
672 | cosine_sims = nutils.cos_mat(encoder.W_enc.T, encoder2.W_enc[:, is_common2].T)
673 | max_cosine_sim = cosine_sims.max(-1).values
674 | feature_df["max_cos"] = to_numpy(max_cosine_sim)
675 | feature_df
676 | # %%
677 | px.histogram(feature_df, x="neuron_kurt", marginal="box", color="is_common", histnorm="percent", title="Neuron Kurtosis", barmode="overlay", hover_name=feature_df.index).show()
678 | px.histogram(feature_df, x="neuron_frac_max", marginal="box", color="is_common", histnorm="percent", title="Neuron Frac Max", barmode="overlay", hover_name=feature_df.index).show()
679 | px.histogram(feature_df, x="vocab_kurt", marginal="box", color="is_common", histnorm="percent", title="Vocab Kurtosis", barmode="overlay", hover_name=feature_df.index).show()
680 | px.scatter(feature_df, x="neuron_kurt", y="neuron_frac_max", hover_name=feature_df.index).show()
681 | # %%
682 | top_features = feature_df.sort_values("neuron_kurt", ascending=False).head(20).index.tolist()
683 | line(encoder.W_enc[:, top_features].T, line_labels=top_features, title="Top Features in Neuron Basis", xaxis="Neuron")
684 | # # %%
685 | # f_id = 6
686 | # token_df = nutils.make_token_df(tokens, 8)
687 | # token_df["act"] = to_numpy(hidden_acts[:, f_id])
688 | # token_df["active"] = to_numpy(hidden_acts[:, f_id]>0)
689 | # token_df = token_df.sort_values("act", ascending=False)
690 | # nutils.show_df(token_df.head(100))
691 |
692 | # i = 0
693 | # make_feature_vis_gradio(token_df.batch.iloc[i], token_df.pos.iloc[i], f_id)
694 | # # %%
695 | # is_rare = ~feature_df.is_common.values
696 | # U, S, Vh = torch.linalg.svd(encoder.W_enc[:, ~is_rare])
697 | # line(S, title="Singular Values of common features")
698 | # histogram(U[:, :5], barmode="overlay", title="MLP side singular vectors common")
699 | # histogram(Vh[:, :5], barmode="overlay", title="Feature side singular vectors common")
700 |
701 | # U, S, Vh = torch.linalg.svd(encoder.W_enc[:, is_rare])
702 | # line(S, title="Singular Values of rare features")
703 | # histogram(U[:, :5], barmode="overlay", title="MLP side singular vectors rare")
704 | # histogram(Vh[:, :5], barmode="overlay", title="Feature side singular vectors rare")
705 |
706 | # token_df = nutils.make_token_df(tokens, 8, 3)
707 | # token_df["rare_ave_feature"] = to_numpy(mlp_acts_flattened @ U[:, 0] * S[0] * 0.01)
708 | # token_df["num_rare_active"] = to_numpy((hidden_acts[:, is_rare]>0).float().mean(-1))
709 | # token_df["num_com_active"] = to_numpy((hidden_acts[:, ~is_rare]>0).float().mean(-1))
710 | # token_df["mlp_act_norm"] = to_numpy(mlp_acts_flattened.norm(dim=-1))
711 | # nutils.show_df(token_df.sort_values("rare_ave_feature", ascending=False).head(50))
712 | # nutils.show_df(token_df.sort_values("rare_ave_feature", ascending=False).tail(50))
713 | # svd_logit_lens = U[:, 0] @ model.W_out[0] @ model.W_U
714 | # nutils.show_df(nutils.create_vocab_df(svd_logit_lens).head(20))
715 | # # %%
716 | # nutils.show_df(feature_df.query("is_common").head(20))
717 | # %%
718 | torch.set_grad_enabled(False)
719 | f_id = 12
720 | print(feature_df.loc[f_id])
721 | tokens = all_tokens[:512]
722 |
723 | _, cache = model.run_with_cache(tokens, stop_at_layer=1, names_filter=utils.get_act_name("post", 0))
724 | mlp_acts = cache[utils.get_act_name("post", 0)]
725 | mlp_acts_flattened = mlp_acts.reshape(-1, cfg["d_mlp"])
726 | hidden_acts = F.relu((mlp_acts_flattened - encoder.b_dec) @ encoder.W_enc + encoder.b_enc)
727 |
728 | token_df = nutils.make_token_df(tokens, 8)
729 | token_df["act"] = to_numpy(hidden_acts[:, f_id])
730 | token_df["active"] = to_numpy(hidden_acts[:, f_id]>0)
731 | token_df = token_df.sort_values("act", ascending=False)
732 | nutils.show_df(token_df.head(50))
733 | hidden = hidden_acts[:, f_id].reshape(tokens.shape)
734 | ave_firing = (hidden>0).float().mean(-1)
735 | ave_act = (hidden).mean(-1)
736 | big_fire_thresh = 0.2 * token_df.act.max()
737 | ave_act_cond = (hidden).sum(-1) / ((hidden>0).float().sum(-1)+1e-7)
738 | line([ave_firing, ave_act, ave_act_cond], line_labels=["Freq firing", "Ave act", "Ave act if firing"], title="Per batch summary statistics")
739 |
740 | argmax_token = tokens.flatten()[hidden.flatten().argmax(-1).cpu()]
741 | argmax_str_token = model.to_string(argmax_token)
742 | print(argmax_token, argmax_str_token)
743 | pos_token_df = token_df[token_df.act>0]
744 | frac_of_fires_are_top_token = (pos_token_df.str_tokens==argmax_str_token).sum()/len(pos_token_df)
745 | frac_big_firing_on_top_token = (pos_token_df.query(f"act>{big_fire_thresh}").str_tokens==argmax_str_token).sum()/len(pos_token_df.query(f"act>{big_fire_thresh}"))
746 | frac_of_top_token_are_fires = (hidden.flatten().cpu()[tokens.flatten()==argmax_token]>0).float().mean().item()
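# The first two fractions are precision-style stats (how often a firing, or a "big" firing above
# 20% of the max activation, lands on the argmax token); the third is recall-style (how often
# the argmax token actually makes the feature fire).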
747 | print(f"{frac_of_fires_are_top_token=:.2%}")
748 | print(f"{frac_big_firing_on_top_token=:.2%}")
749 | print(f"{frac_of_top_token_are_fires=:.2%}")
750 | print(f"Sample size = {(tokens.flatten()==argmax_token).sum().item()}")
751 |
752 | line([encoder.W_enc[:, f_id], encoder.W_dec[f_id, :]], xaxis="Neuron", title="Weights in the neuron basis", line_labels=["encoder", "decoder"])
753 |
754 | nutils.show_df(nutils.create_vocab_df(feature_U[f_id]).head(20))
755 | nutils.show_df(nutils.create_vocab_df(feature_U[f_id]).tail(10))
756 | i = 0
757 | make_feature_vis_gradio(token_df.batch.iloc[i], token_df.pos.iloc[i], f_id)
758 | # %%
759 | # %%
760 | strings = [
761 |     "and| they|",
762 |     "and| you|",
763 |     "and| we|",
764 |     "and| it|",
765 |     "and| I|",
766 |     "and| she|",
767 |     "but| they|",
768 |     "but| you|",
769 |     "but| we|",
770 |     "but| it|",
771 |     "but| I|",
772 |     "but| she|",
773 |     "or| they|",
774 |     "or| you|",
775 |     "or| we|",
776 |     "or| it|",
777 |     "or| I|",
778 |     "or| she|",
779 |     # "but| they|",
780 |     # "but| you|",
781 |     # "but| we|",
782 |     # "but| it|",
783 |     # "but| I|",
784 |     # "but| she|",
785 | ]
786 | token_df["and_pronoun"] = [any(x in c for x in strings) for c in token_df.context]
787 | px.histogram(token_df, x="act", marginal="box", color="and_pronoun", histnorm="percent", title="Neuron activation for pronouns", barmode="overlay", hover_name="context").show()
788 | # %%
789 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | # %%
2 | from utils import *
3 | # %%
4 | encoder = AutoEncoder(cfg)
5 | buffer = Buffer(cfg)
6 | # Code used to remove the "rare freq direction", the shared direction among the ultra low frequency features.
7 | # I experimented with removing it and retraining the autoencoder.
8 | if cfg["remove_rare_dir"]:
9 |     rare_freq_dir = torch.load("rare_freq_dir.pt")
10 |     rare_freq_dir.requires_grad = False
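    # A hedged sketch (hypothetical, not used for the open source autoencoders): assuming
    # rare_freq_dir is a unit vector of shape [act_size], the direction could be projected out
    # of each activation batch via:
    #     acts = acts - (acts @ rare_freq_dir)[:, None] * rare_freq_dir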
11 |
12 | # %%
13 | try:
14 |     wandb.init(project="autoencoder", entity="neelnanda-io")
15 |     num_batches = cfg["num_tokens"] // cfg["batch_size"]
16 |     # model_num_batches = cfg["model_batch_size"] * num_batches
17 |     encoder_optim = torch.optim.Adam(encoder.parameters(), lr=cfg["lr"], betas=(cfg["beta1"], cfg["beta2"]))
18 |     recons_scores = []
19 |     act_freq_scores_list = []
20 |     for i in tqdm.trange(num_batches):
21 |         i = i % all_tokens.shape[0]  # apparent leftover from an earlier indexing scheme; a no-op here, since num_batches is much smaller than all_tokens.shape[0]
22 |         acts = buffer.next()
23 |         loss, x_reconstruct, mid_acts, l2_loss, l1_loss = encoder(acts)
24 |         loss.backward()
25 |         encoder.make_decoder_weights_and_grad_unit_norm()
26 |         encoder_optim.step()
27 |         encoder_optim.zero_grad()
28 |         loss_dict = {"loss": loss.item(), "l2_loss": l2_loss.item(), "l1_loss": l1_loss.item()}
29 |         del loss, x_reconstruct, mid_acts, l2_loss, l1_loss, acts
30 |         if (i) % 100 == 0:
31 |             wandb.log(loss_dict)
32 |             print(loss_dict)
33 |         if (i) % 1000 == 0:
34 |             x = (get_recons_loss(local_encoder=encoder))
35 |             print("Reconstruction:", x)
36 |             recons_scores.append(x[0])
37 |             freqs = get_freqs(5, local_encoder=encoder)
38 |             act_freq_scores_list.append(freqs)
39 |             # histogram(freqs.log10(), marginal="box", histnorm="percent", title="Frequencies")
40 |             wandb.log({
41 |                 "recons_score": x[0],
42 |                 "dead": (freqs==0).float().mean().item(),
43 |                 "below_1e-6": (freqs<1e-6).float().mean().item(),
44 |                 "below_1e-5": (freqs<1e-5).float().mean().item(),
45 |             })
46 |         if (i+1) % 30000 == 0:
47 |             encoder.save()
48 |             wandb.log({"reset_neurons": 0.0})
49 |             freqs = get_freqs(50, local_encoder=encoder)
50 |             to_be_reset = (freqs<10**(-5.5))
51 |             print("Resetting neurons!", to_be_reset.sum())
52 |             re_init(to_be_reset, encoder)
53 | finally:
54 |     encoder.save()
55 | # %%
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | # %%
2 | import os
3 | os.environ["TRANSFORMERS_CACHE"] = "/workspace/cache/"
4 | os.environ["DATASETS_CACHE"] = "/workspace/cache/"
5 | # %%
6 | from neel.imports import *
7 | from neel_plotly import *
8 | import wandb
9 | # %%
10 | import argparse
11 | def arg_parse_update_cfg(default_cfg):
12 | """
13 | Helper function to take in a dictionary of arguments, convert these to command line arguments, look at what was passed in, and return an updated dictionary.
14 |
15 |     If run in IPython, it just returns the defaults with no changes.
16 | """
17 | if get_ipython() is not None:
18 | # Is in IPython
19 | print("In IPython - skipped argparse")
20 | return default_cfg
21 | cfg = dict(default_cfg)
22 | parser = argparse.ArgumentParser()
23 | for key, value in default_cfg.items():
24 | if type(value) == bool:
25 |             # argparse handles booleans badly, so each boolean key becomes a flag that flips its default: passing --{key} sets it to False if the default is True, and to True if the default is False
26 | if value:
27 | parser.add_argument(f"--{key}", action="store_false")
28 | else:
29 | parser.add_argument(f"--{key}", action="store_true")
30 |
31 | else:
32 | parser.add_argument(f"--{key}", type=type(value), default=value)
33 | args = parser.parse_args()
34 | parsed_args = vars(args)
35 | cfg.update(parsed_args)
36 | print("Updated config")
37 | print(json.dumps(cfg, indent=2))
38 | return cfg
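# Example invocation (utils.py parses these at import time, so flags passed to train.py work):
#   python train.py --lr 3e-4 --l1_coeff 1e-3 --remove_rare_dir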
39 | default_cfg = {
40 | "seed": 49,
41 | "batch_size": 4096,
42 | "buffer_mult": 384,
43 | "lr": 1e-4,
44 | "num_tokens": int(2e9),
45 | "l1_coeff": 3e-4,
46 | "beta1": 0.9,
47 | "beta2": 0.99,
48 | "dict_mult": 32,
49 | "seq_len": 128,
50 | "enc_dtype":"fp32",
51 | "remove_rare_dir": False,
52 | "model_name": "gelu-2l",
53 | "site": "mlp_out",
54 | "layer": 0,
55 | "device": "cuda:0"
56 | }
57 | site_to_size = {
58 | "mlp_out": 512,
59 | "post": 2048,
60 | "resid_pre": 512,
61 | "resid_mid": 512,
62 | "resid_post": 512,
63 | }
64 |
65 | cfg = arg_parse_update_cfg(default_cfg)
66 | def post_init_cfg(cfg):
67 | cfg["model_batch_size"] = cfg["batch_size"] // cfg["seq_len"] * 16
68 | cfg["buffer_size"] = cfg["batch_size"] * cfg["buffer_mult"]
69 | cfg["buffer_batches"] = cfg["buffer_size"] // cfg["seq_len"]
70 | cfg["act_name"] = utils.get_act_name(cfg["site"], cfg["layer"])
71 | cfg["act_size"] = site_to_size[cfg["site"]]
72 | cfg["dict_size"] = cfg["act_size"] * cfg["dict_mult"]
73 | cfg["name"] = f"{cfg['model_name']}_{cfg['layer']}_{cfg['dict_size']}_{cfg['site']}"
74 | post_init_cfg(cfg)
75 | pprint.pprint(cfg)
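# Worked example with the defaults above: model_batch_size = 4096 // 128 * 16 = 512 sequences,
# buffer_size = 4096 * 384 = 1,572,864 activations, buffer_batches = 1,572,864 // 128 = 12,288,
# and for site "mlp_out" (act_size 512), dict_size = 512 * 32 = 16,384 features.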
76 | # %%
77 |
78 | SEED = cfg["seed"]
79 | GENERATOR = torch.manual_seed(SEED)
80 | DTYPES = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}
81 | np.random.seed(SEED)
82 | random.seed(SEED)
83 | torch.set_grad_enabled(True)
84 |
85 | model = HookedTransformer.from_pretrained(cfg["model_name"]).to(DTYPES[cfg["enc_dtype"]]).to(cfg["device"])
86 |
87 | n_layers = model.cfg.n_layers
88 | d_model = model.cfg.d_model
89 | n_heads = model.cfg.n_heads
90 | d_head = model.cfg.d_head
91 | d_mlp = model.cfg.d_mlp
92 | d_vocab = model.cfg.d_vocab
93 | # %%
94 | @torch.no_grad()
95 | def get_acts(tokens, batch_size=1024):
96 | _, cache = model.run_with_cache(tokens, stop_at_layer=cfg["layer"]+1, names_filter=cfg["act_name"])
97 | acts = cache[cfg["act_name"]]
98 | acts = acts.reshape(-1, acts.shape[-1])
99 | subsample = torch.randperm(acts.shape[0], generator=GENERATOR)[:batch_size]
100 | subsampled_acts = acts[subsample, :]
101 | return subsampled_acts, acts
102 | # sub, acts = get_acts(torch.arange(20).reshape(2, 10), batch_size=3)
103 | # sub.shape, acts.shape
104 | # %%
105 | SAVE_DIR = Path("/workspace/1L-Sparse-Autoencoder/checkpoints")
106 | class AutoEncoder(nn.Module):
107 | def __init__(self, cfg):
108 | super().__init__()
109 | d_hidden = cfg["dict_size"]
110 | l1_coeff = cfg["l1_coeff"]
111 | dtype = DTYPES[cfg["enc_dtype"]]
112 | torch.manual_seed(cfg["seed"])
113 | self.W_enc = nn.Parameter(torch.nn.init.kaiming_uniform_(torch.empty(cfg["act_size"], d_hidden, dtype=dtype)))
114 | self.W_dec = nn.Parameter(torch.nn.init.kaiming_uniform_(torch.empty(d_hidden, cfg["act_size"], dtype=dtype)))
115 | self.b_enc = nn.Parameter(torch.zeros(d_hidden, dtype=dtype))
116 | self.b_dec = nn.Parameter(torch.zeros(cfg["act_size"], dtype=dtype))
117 |
118 | self.W_dec.data[:] = self.W_dec / self.W_dec.norm(dim=-1, keepdim=True)
119 |
120 | self.d_hidden = d_hidden
121 | self.l1_coeff = l1_coeff
122 |
123 | self.to(cfg["device"])
124 |
125 | def forward(self, x):
126 | x_cent = x - self.b_dec
127 | acts = F.relu(x_cent @ self.W_enc + self.b_enc)
128 | x_reconstruct = acts @ self.W_dec + self.b_dec
129 | l2_loss = (x_reconstruct.float() - x.float()).pow(2).sum(-1).mean(0)
130 | l1_loss = self.l1_coeff * (acts.float().abs().sum())
131 | loss = l2_loss + l1_loss
132 | return loss, x_reconstruct, acts, l2_loss, l1_loss
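    # As implemented, loss = mean over the batch of ||x_reconstruct - x||^2 (summed over the
    # activation dimension) plus l1_coeff * |acts| summed over both batch and dictionary
    # (the L1 term is a sum, not a mean).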
133 |
134 | @torch.no_grad()
135 | def make_decoder_weights_and_grad_unit_norm(self):
136 | W_dec_normed = self.W_dec / self.W_dec.norm(dim=-1, keepdim=True)
137 | W_dec_grad_proj = (self.W_dec.grad * W_dec_normed).sum(-1, keepdim=True) * W_dec_normed
138 | self.W_dec.grad -= W_dec_grad_proj
139 | # Bugfix(?) for ensuring W_dec retains unit norm, this was not there when I trained my original autoencoders.
140 | self.W_dec.data = W_dec_normed
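        # Why the data is renormalized too: a unit vector plus a perpendicular update of size
        # eps has norm sqrt(1 + eps^2) > 1, so without this the decoder norms drift upward.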
141 |
142 | def get_version(self):
143 | version_list = [int(file.name.split(".")[0]) for file in list(SAVE_DIR.iterdir()) if "pt" in str(file)]
144 | if len(version_list):
145 | return 1+max(version_list)
146 | else:
147 | return 0
148 |
149 | def save(self):
150 | version = self.get_version()
151 | torch.save(self.state_dict(), SAVE_DIR/(str(version)+".pt"))
152 | with open(SAVE_DIR/(str(version)+"_cfg.json"), "w") as f:
153 | json.dump(cfg, f)
154 | print("Saved as version", version)
155 |
156 | @classmethod
157 | def load(cls, version):
158 | cfg = (json.load(open(SAVE_DIR/(str(version)+"_cfg.json"), "r")))
159 | pprint.pprint(cfg)
160 | self = cls(cfg=cfg)
161 | self.load_state_dict(torch.load(SAVE_DIR/(str(version)+".pt")))
162 | return self
163 |
164 | @classmethod
165 | def load_from_hf(cls, version):
166 | """
167 | Loads the saved autoencoder from HuggingFace.
168 |
169 | Version is expected to be an int, or "run1" or "run2"
170 |
171 | version 25 is the final checkpoint of the first autoencoder run,
172 | version 47 is the final checkpoint of the second autoencoder run.
173 | """
174 | if version=="run1":
175 | version = 25
176 | elif version=="run2":
177 | version = 47
178 |
179 | cfg = utils.download_file_from_hf("NeelNanda/sparse_autoencoder", f"{version}_cfg.json")
180 | pprint.pprint(cfg)
181 | self = cls(cfg=cfg)
182 | self.load_state_dict(utils.download_file_from_hf("NeelNanda/sparse_autoencoder", f"{version}.pt", force_is_torch=True))
183 | return self
184 |
185 | # %%
186 |
187 |
188 |
189 | # %%
190 | def shuffle_data(all_tokens):
191 | print("Shuffled data")
192 | return all_tokens[torch.randperm(all_tokens.shape[0])]
193 |
194 | loading_data_first_time = False
195 | if loading_data_first_time:
196 | data = load_dataset("NeelNanda/c4-code-tokenized-2b", split="train", cache_dir="/workspace/cache/")
197 | data.save_to_disk("/workspace/data/c4_code_tokenized_2b.hf")
198 | data.set_format(type="torch", columns=["tokens"])
199 | all_tokens = data["tokens"]
200 | all_tokens.shape
201 |
202 |
203 | all_tokens_reshaped = einops.rearrange(all_tokens, "batch (x seq_len) -> (batch x) seq_len", x=8, seq_len=128)
204 | all_tokens_reshaped[:, 0] = model.tokenizer.bos_token_id
205 | all_tokens_reshaped = all_tokens_reshaped[torch.randperm(all_tokens_reshaped.shape[0])]
206 | torch.save(all_tokens_reshaped, "/workspace/data/c4_code_2b_tokens_reshaped.pt")
207 | else:
208 | # data = datasets.load_from_disk("/workspace/data/c4_code_tokenized_2b.hf")
209 | all_tokens = torch.load("/workspace/data/c4_code_2b_tokens_reshaped.pt")
210 | all_tokens = shuffle_data(all_tokens)
211 |
212 | # %%
213 | class Buffer():
214 | """
215 | This defines a data buffer, to store a bunch of MLP acts that can be used to train the autoencoder. It'll automatically run the model to generate more when it gets halfway empty.
216 | """
217 | def __init__(self, cfg):
218 | self.buffer = torch.zeros((cfg["buffer_size"], cfg["act_size"]), dtype=torch.bfloat16, requires_grad=False).to(cfg["device"])
219 | self.cfg = cfg
220 | self.token_pointer = 0
221 | self.first = True
222 | self.refresh()
223 |
224 | @torch.no_grad()
225 | def refresh(self):
226 | self.pointer = 0
227 | with torch.autocast("cuda", torch.bfloat16):
228 | if self.first:
229 | num_batches = self.cfg["buffer_batches"]
230 | else:
231 | num_batches = self.cfg["buffer_batches"]//2
232 | self.first = False
233 | for _ in range(0, num_batches, self.cfg["model_batch_size"]):
234 | tokens = all_tokens[self.token_pointer:self.token_pointer+self.cfg["model_batch_size"]]
235 | _, cache = model.run_with_cache(tokens, stop_at_layer=cfg["layer"]+1, names_filter=cfg["act_name"])
236 | acts = cache[cfg["act_name"]].reshape(-1, self.cfg["act_size"])
237 |
238 | # print(tokens.shape, acts.shape, self.pointer, self.token_pointer)
239 | self.buffer[self.pointer: self.pointer+acts.shape[0]] = acts
240 | self.pointer += acts.shape[0]
241 | self.token_pointer += self.cfg["model_batch_size"]
242 | # if self.token_pointer > all_tokens.shape[0] - self.cfg["model_batch_size"]:
243 | # self.token_pointer = 0
244 |
245 | self.pointer = 0
246 | self.buffer = self.buffer[torch.randperm(self.buffer.shape[0]).to(cfg["device"])]
247 |
248 | @torch.no_grad()
249 | def next(self):
250 | out = self.buffer[self.pointer:self.pointer+self.cfg["batch_size"]]
251 | self.pointer += self.cfg["batch_size"]
252 | if self.pointer > self.buffer.shape[0]//2 - self.cfg["batch_size"]:
253 | # print("Refreshing the buffer!")
254 | self.refresh()
255 | return out
256 |
257 | # buffer.refresh()
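# Design note: after the first fill, refresh() regenerates only half the buffer and then
# reshuffles the whole thing, and next() triggers a refresh once the read pointer passes the
# halfway mark - so each training batch is a shuffled mix of activations from many contexts.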
258 | # %%
259 |
260 | # %%
261 | def replacement_hook(mlp_post, hook, encoder):
262 | mlp_post_reconstr = encoder(mlp_post)[1]
263 | return mlp_post_reconstr
264 |
265 | def mean_ablate_hook(mlp_post, hook):
266 | mlp_post[:] = mlp_post.mean([0, 1])
267 | return mlp_post
268 |
269 | def zero_ablate_hook(mlp_post, hook):
270 | mlp_post[:] = 0.
271 | return mlp_post
272 |
273 | @torch.no_grad()
274 | def get_recons_loss(num_batches=5, local_encoder=None):
275 | if local_encoder is None:
276 | local_encoder = encoder
277 | loss_list = []
278 | for i in range(num_batches):
279 | tokens = all_tokens[torch.randperm(len(all_tokens))[:cfg["model_batch_size"]]]
280 | loss = model(tokens, return_type="loss")
281 | recons_loss = model.run_with_hooks(tokens, return_type="loss", fwd_hooks=[(cfg["act_name"], partial(replacement_hook, encoder=local_encoder))])
282 | # mean_abl_loss = model.run_with_hooks(tokens, return_type="loss", fwd_hooks=[(cfg["act_name"], mean_ablate_hook)])
283 | zero_abl_loss = model.run_with_hooks(tokens, return_type="loss", fwd_hooks=[(cfg["act_name"], zero_ablate_hook)])
284 | loss_list.append((loss, recons_loss, zero_abl_loss))
285 | losses = torch.tensor(loss_list)
286 | loss, recons_loss, zero_abl_loss = losses.mean(0).tolist()
287 |
288 | print(loss, recons_loss, zero_abl_loss)
289 | score = ((zero_abl_loss - recons_loss)/(zero_abl_loss - loss))
290 | print(f"{score:.2%}")
291 | # print(f"{((zero_abl_loss - mean_abl_loss)/(zero_abl_loss - loss)).item():.2%}")
292 | return score, loss, recons_loss, zero_abl_loss
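# score = (zero_abl_loss - recons_loss) / (zero_abl_loss - loss): 1.0 means splicing the
# reconstructions into the forward pass recovers the clean loss; 0.0 means they do no better
# than zero-ablating the activations.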
293 | # print(get_recons_loss())
294 |
295 | # %%
296 | # Frequency
297 | @torch.no_grad()
298 | def get_freqs(num_batches=25, local_encoder=None):
299 | if local_encoder is None:
300 | local_encoder = encoder
301 | act_freq_scores = torch.zeros(local_encoder.d_hidden, dtype=torch.float32).to(cfg["device"])
302 | total = 0
303 | for i in tqdm.trange(num_batches):
304 | tokens = all_tokens[torch.randperm(len(all_tokens))[:cfg["model_batch_size"]]]
305 |
306 | _, cache = model.run_with_cache(tokens, stop_at_layer=cfg["layer"]+1, names_filter=cfg["act_name"])
307 | acts = cache[cfg["act_name"]]
308 | acts = acts.reshape(-1, cfg["act_size"])
309 |
310 | hidden = local_encoder(acts)[2]
311 |
312 | act_freq_scores += (hidden > 0).sum(0)
313 | total+=hidden.shape[0]
314 | act_freq_scores /= total
315 | num_dead = (act_freq_scores==0).float().mean()
316 | print("Num dead", num_dead)
317 | return act_freq_scores
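# A feature's "frequency" is the fraction of tokens on which it fires (activation > 0). The
# analysis code treats features below 10^-3.5 as rare; train.py resets those below 10^-5.5.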
318 | # %%
319 | @torch.no_grad()
320 | def re_init(indices, encoder):
321 | new_W_enc = (torch.nn.init.kaiming_uniform_(torch.zeros_like(encoder.W_enc)))
322 | new_W_dec = (torch.nn.init.kaiming_uniform_(torch.zeros_like(encoder.W_dec)))
323 | new_b_enc = (torch.zeros_like(encoder.b_enc))
324 | print(new_W_dec.shape, new_W_enc.shape, new_b_enc.shape)
325 | encoder.W_enc.data[:, indices] = new_W_enc[:, indices]
326 | encoder.W_dec.data[indices, :] = new_W_dec[indices, :]
327 | encoder.b_enc.data[indices] = new_b_enc[indices]
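# Hypothetical usage, mirroring the periodic reset step in train.py:
# to_be_reset = get_freqs(50, local_encoder=encoder) < 10**(-5.5)
# re_init(to_be_reset, encoder)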
--------------------------------------------------------------------------------