├── .gitignore ├── LICENSE ├── README.md ├── analysis.py ├── rare_freq_dir.pt ├── scratch.py ├── train.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | wandb/* -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Neel Nanda 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # TLDR 2 | 3 | This is an open source replication of [Anthropic's Towards Monosemanticity](https://transformer-circuits.pub/2023/monosemantic-features/index.html) paper. The autoencoder was trained on the gelu-1l model in TransformerLens; you can access two trained autoencoders and the model using [this tutorial](https://colab.research.google.com/drive/1u8larhpxy8w4mMsJiSBddNOzFGj7_RTn#scrollTo=MYrIYDEfBtbL). 4 | 5 | # Reading This Codebase 6 | 7 | This is a pretty scrappy training codebase and won't run from the top. I mostly recommend reading the code and copying snippets. See also [Hoagy Cunningham's Github](https://github.com/HoagyC/sparse_coding). 8 | 9 | * `utils.py` contains various utils to define the Autoencoder, the data Buffer, and the training data. 10 | * Toggle `loading_data_first_time` to True to load and process the text data used to run the model and generate activations. 11 | * `train.py` is a scrappy training script. 12 | * `cfg["remove_rare_dir"]` was an experiment in training an autoencoder whose features were all orthogonal to the shared direction among rare features; those lines of code can be ignored and weren't used for the open source autoencoders. 13 | * There was a bug in the code that sets the decoder weights to unit norm - it makes the gradients orthogonal, but I forgot to *also* set the norm back to 1 after each gradient update (it turns out a unit-norm vector plus a perpendicular vector does not remain unit norm!). I think I have now fixed the bug; see the sketch below. 14 | * `analysis.py` is a scrappy set of experiments for exploring the autoencoder. I recommend reading the Colab tutorial instead for something cleaner and better commented.
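For reference, here is a minimal sketch of the fixed constraint described above. `remove_parallel_component_of_grads` mirrors the method of the same name on the `AutoEncoder` class in this repo; `renormalize_decoder` is a name invented here for illustration, covering the step that was originally missing:

```python
import torch

@torch.no_grad()
def remove_parallel_component_of_grads(encoder):
    # Subtract the component of each decoder row's gradient that is parallel
    # to the row itself, so updates stay tangent to the unit sphere.
    W_dec_normed = encoder.W_dec / encoder.W_dec.norm(dim=-1, keepdim=True)
    W_dec_grad_proj = (encoder.W_dec.grad * W_dec_normed).sum(-1, keepdim=True) * W_dec_normed
    encoder.W_dec.grad -= W_dec_grad_proj

@torch.no_grad()
def renormalize_decoder(encoder):
    # The previously missing step: a unit-norm row plus a perpendicular update
    # is no longer unit norm, so re-project the rows after each optimizer step.
    encoder.W_dec.data /= encoder.W_dec.data.norm(dim=-1, keepdim=True)

# Hypothetical training-loop usage:
#   loss.backward()
#   remove_parallel_component_of_grads(encoder)
#   encoder_optim.step()
#   encoder_optim.zero_grad()
#   renormalize_decoder(encoder)
```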
15 | 16 | Setup Notes: 17 | 18 | * Create data - you'll need to set the flag `loading_data_first_time` to True in `utils.py`. Note that this downloads the training mix of gelu-1l; if using e.g. the Pythia models you'll need different data (I recommend https://huggingface.co/datasets/monology/pile-uncopyrighted ). 19 | * A bunch of folders are hard-coded to be /workspace/...; change this for your system. 20 | * Create a checkpoints dir at /workspace/1L-Sparse-Autoencoder/checkpoints. 21 | 22 | * If you train an autoencoder and want to share the weights, copy the final checkpoints to a new folder, create your own repo, and use `upload_folder_to_hf` to upload to HuggingFace. Run `huggingface-cli login` to log in, then `apt-get install git-lfs` followed by `git lfs install`. -------------------------------------------------------------------------------- /analysis.py: -------------------------------------------------------------------------------- 1 | # %% 2 | new_cfg = { 3 | "layer": 1, 4 | } 5 | from utils import * 6 | torch.set_grad_enabled(False) 7 | cfg.update(new_cfg) 8 | post_init_cfg(cfg) 9 | # %% 10 | tokens = all_tokens[:256] 11 | _, cache = model.run_with_cache(tokens, stop_at_layer=cfg["layer"]+1, names_filter=cfg["act_name"]) 12 | acts = cache[cfg["act_name"]] 13 | acts_flattened = acts.reshape(-1, cfg["act_size"]) 14 | encoder = AutoEncoder.load("gelu-2l_L1_16384_mlp_out_50") 15 | hidden_acts = F.relu((acts_flattened - encoder.b_dec) @ encoder.W_enc + encoder.b_enc) 16 | reconstr = hidden_acts @ encoder.W_dec + encoder.b_dec 17 | l2_loss = (acts_flattened - reconstr).pow(2).sum(-1).mean(0) 18 | l1_loss = encoder.l1_coeff * (hidden_acts.abs().sum()) 19 | print(l2_loss, l1_loss) 20 | # %% 21 | freqs = get_freqs(25, encoder) 22 | # %% 23 | histogram((freqs+10**-6.5).log10(), histnorm="percent", title="Frequencies for Final Checkpoint", xaxis="Freq (Log10)", yaxis="Percent") 24 | 25 | # %% 26 | is_rare = freqs < 1e-4 27 | 28 | is_rare[0] = True 29 | 30 | rare_enc = encoder.W_enc[:, is_rare] 31 | rare_mean = rare_enc.mean(-1) 32 | 33 | cos_with_mean = (rare_mean @ rare_enc) / rare_mean.norm() / rare_enc.norm(dim=0) 34 | histogram(cos_with_mean, histnorm="percent", marginal="box", title="Cosine sim of rare features with mean rare direction", yaxis="percent", xaxis="Cosine Sim") 35 | proj_onto_mean = (rare_mean @ rare_enc) / rare_mean.norm() 36 | histogram(proj_onto_mean, histnorm="percent", marginal="box", title="Projection of rare features onto mean rare direction", yaxis="percent", xaxis="Projection") 37 | 38 | print((cos_with_mean > 0.95).float().mean()) 39 | 40 | scatter(x=proj_onto_mean, y=(freqs[is_rare]+10**-6.5).log10()) 41 | scatter(x=cos_with_mean, y=(freqs[is_rare]+10**-6.5).log10()) 42 | # %% 43 | 44 | feature_df = pd.DataFrame() 45 | feature_kurtosis = scipy.stats.kurtosis(to_numpy(encoder.W_dec.T)) 46 | feature_df["neuron_kurt"] = to_numpy(feature_kurtosis) 47 | feature_kurtosis_enc = scipy.stats.kurtosis(to_numpy(encoder.W_enc)) 48 | feature_df["neuron_kurt_enc"] = to_numpy(feature_kurtosis_enc) 49 | feature_df["is_rare"] = to_numpy(is_rare) 50 | feature_df["freq"] = to_numpy(freqs) 51 | # %% 52 | # encoder2 = AutoEncoder.load(47) 53 | # freqs2 = get_freqs(25, encoder2) 54 | # is_rare2 = freqs2 < 1e-4 55 | # rare_enc2 = encoder2.W_enc[:, is_rare2] 56 | # rare_mean2 = rare_enc2.mean(-1) 57 | # cos_with_mean2 = (rare_mean2 @ rare_enc2) / rare_mean2.norm() / rare_enc2.norm(dim=0) 58 | # histogram(cos_with_mean2, histnorm="percent", marginal="box", title="Cosine sim of rare features with mean 
rare direction", yaxis="percent", xaxis="Cosine Sim") 59 | # # %% 60 | # rare_mean2 @ rare_mean / rare_mean2.norm() / rare_mean.norm() 61 | # # %% 62 | # def basic_feature_vis(text, feature_index, max_val=0): 63 | # feature_in = encoder.W_enc[:, feature_index] 64 | # feature_bias = encoder.b_enc[feature_index] 65 | # _, cache = model.run_with_cache(text, stop_at_layer=cfg["layer"]+1, names_filter=cfg["act_name"]) 66 | # acts = cache[cfg["act_name"]][0] 67 | # feature_acts = F.relu((acts - encoder.b_dec) @ feature_in + feature_bias) 68 | # if max_val==0: 69 | # max_val = max(1e-7, feature_acts.max().item()) 70 | # # print(max_val) 71 | # # if min_val==0: 72 | # # min_val = min(-1e-7, feature_acts.min().item()) 73 | # return basic_token_vis_make_str(text, feature_acts, max_val) 74 | # def basic_token_vis_make_str(strings, values, max_val=None): 75 | # if not isinstance(strings, list): 76 | # strings = model.to_str_tokens(strings) 77 | # values = to_numpy(values) 78 | # if max_val is None: 79 | # max_val = values.max() 80 | # # if min_val is None: 81 | # # min_val = values.min() 82 | # header_string = f"

Max Range {values.max():.4f} Min Range: {values.min():.4f}

" 83 | # header_string += f"

Set Max Range {max_val:.4f}

" 84 | # # values[values>0] = values[values>0]/ma|x_val 85 | # # values[values<0] = values[values<0]/abs(min_val) 86 | # body_string = nutils.create_html(strings, values, max_value=max_val, return_string=True) 87 | # return header_string + body_string 88 | # display(HTML(basic_token_vis_make_str(tokens[0, :10], acts[0, :10, 7], 0.1))) 89 | # display(HTML(basic_feature_vis("I really like things", 7))) 90 | # %% 91 | # The `with gr.Blocks() as demo:` syntax just creates a variable called demo containing all these components 92 | # import gradio as gr 93 | # try: 94 | # demos[0].close() 95 | # except: 96 | # pass 97 | # demos = [None] 98 | # def make_feature_vis_gradio(batch, pos, feature_id): 99 | # try: 100 | # demos[0].close() 101 | # except: 102 | # pass 103 | # with gr.Blocks() as demo: 104 | # gr.HTML(value=f"Hacky Interactive Neuroscope for gelu-1l") 105 | # # The input elements 106 | # with gr.Row(): 107 | # with gr.Column(): 108 | # text = gr.Textbox(label="Text", value=model.to_string(tokens[batch, 1:pos+1])) 109 | # # Precision=0 makes it an int, otherwise it's a float 110 | # # Value sets the initial default value 111 | # feature_index = gr.Number( 112 | # label="Feature Index", value=feature_id, precision=0 113 | # ) 114 | # # # If empty, these two map to None 115 | # max_val = gr.Number(label="Max Value", value=None) 116 | # # min_val = gr.Number(label="Min Value", value=None) 117 | # inputs = [text, feature_index, max_val] 118 | # with gr.Row(): 119 | # with gr.Column(): 120 | # # The output element 121 | # out = gr.HTML(label="Neuron Acts", value=basic_feature_vis(model.to_string(tokens[batch, 1:pos+1]), feature_id)) 122 | # for inp in inputs: 123 | # inp.change(basic_feature_vis, inputs, out) 124 | # demo.launch(share=True) 125 | # demos[0] = demo 126 | # %% 127 | 128 | # values, indices = (hidden_acts[:, is_rare]>0).float().mean(-1).sort() 129 | # # histogram((acts_flattened @ rare_mean)) 130 | # token_df = nutils.make_token_df(tokens) 131 | # token_df["rare_proj"] = to_numpy(acts_flattened @ rare_mean) 132 | # token_df["frac_rare_active"] = to_numpy((hidden_acts[:, is_rare]>0).float().mean(-1)) 133 | # sorted_token_df = token_df.sort_values("rare_proj", ascending=False) 134 | # for i in range(3): 135 | # b = sorted_token_df.batch.iloc[i] 136 | # p = sorted_token_df.pos.iloc[i] 137 | # print(f"Frac rare features active on final token: {sorted_token_df.frac_rare_active.iloc[i]:.2%}") 138 | # curr_tokens = tokens[b, :p+1] 139 | # values = acts[b, :p+1] @ rare_mean 140 | # nutils.create_html(model.to_str_tokens(curr_tokens), values) 141 | 142 | # %% 143 | # # %% 144 | # sorted_W_enc = encoder.W_enc.abs().sort(dim=0).values 145 | # sorted_W_enc_sq = sorted_W_enc.pow(2) 146 | # sorted_W_enc_sq_sum = sorted_W_enc.pow(2).sum(0) 147 | # for k in [1, 2, 5, 10, 100]: 148 | # feature_df[f"fve_top_{k}"] = to_numpy(sorted_W_enc_sq[-k:, :].sum(0) / sorted_W_enc_sq_sum) 149 | # feature_df["fve_next_9"] = feature_df["fve_top_10"] - feature_df["fve_top_1"] 150 | # feature_df.sort_values("neuron_kurt", ascending=False) 151 | # # %% 152 | # px.histogram(feature_df, x="fve_top_1", histnorm="percent", cumulative=True, marginal="box").show() 153 | # px.histogram(feature_df, x="fve_top_5", histnorm="percent", cumulative=True, marginal="box").show() 154 | # px.histogram(feature_df, x="fve_top_10", histnorm="percent", cumulative=True, marginal="box").show() 155 | # px.histogram(feature_df, x="fve_top_100", marginal="box").show() 156 | # # %% 157 | # px.scatter(feature_df.query("~is_rare"), 
x="fve_top_1", y="fve_next_9", hover_name=feature_df.query("~is_rare").index, color_continuous_scale="Portland", marginal_x="histogram", marginal_y="histogram", title="Fraction of Squared Sum Explained by Top Neuron vs Next 9 Neurons", opacity=0.2, labels={"fve_top_1":"Frac Explained by Top Neuron", "fve_next_9":"Frac Explained by Next 9 Neurons"}) 158 | # # %% 159 | # sorted_W_dec = encoder.W_dec.T.abs().sort(dim=0).values 160 | # sorted_W_dec_sq = sorted_W_dec.pow(2) 161 | # sorted_W_dec_sq_sum = sorted_W_dec.pow(2).sum(0) 162 | # for k in [1, 2, 5, 10, 100]: 163 | # feature_df[f"fve_top_{k}_dec"] = to_numpy(sorted_W_dec_sq[-k:, :].sum(0) / sorted_W_dec_sq_sum) 164 | # feature_df["fve_next_9_dec"] = feature_df["fve_top_10_dec"] - feature_df["fve_top_1_dec"] 165 | # px.scatter(feature_df.query("~is_rare"), x='fve_top_1', y="fve_top_1_dec").show() 166 | # # %% 167 | # px.scatter(feature_df.query("~is_rare"), x="fve_top_1_dec", y="fve_next_9_dec", hover_name=feature_df.query("~is_rare").index, color_continuous_scale="Portland", marginal_x="histogram", marginal_y="histogram", title="Fraction of Squared Sum Explained by Top Neuron vs Next 9 Neurons", opacity=0.2, labels={"fve_top_1_dec":"Frac Explained by Top Neuron", "fve_next_9_dec":"Frac Explained by Next 9 Neurons"}) 168 | # # %% 169 | # feature_df["1-sparse"] = (feature_df["fve_top_1_dec"]>0.35) & (feature_df["fve_next_9_dec"]<0.1) 170 | # feature_df["10-sparse"] = (feature_df["fve_top_10_dec"]>0.35) & (~feature_df["1-sparse"]) 171 | # print(f"Frac 1 sparse: {feature_df['1-sparse'].mean():.3f}") 172 | # print(f"Frac 10 sparse: {feature_df['10-sparse'].mean():.3f}") 173 | # def f(row): 174 | # if row["1-sparse"]: 175 | # return "1-sparse" 176 | # elif row["10-sparse"]: 177 | # return "10-sparse" 178 | # else: 179 | # return "dense" 180 | # feature_df["sparsity_label"] = feature_df.apply(f, axis=1) 181 | # px.scatter(feature_df.query("~is_rare"), x="fve_top_1_dec", y="fve_next_9_dec", hover_name=feature_df.query("~is_rare").index, color="sparsity_label", color_continuous_scale="Portland", marginal_x="histogram", marginal_y="histogram", title="Fraction of Squared Decoder Sum Explained by Top Neuron vs Next 9 Neurons", opacity=0.2, labels={"fve_top_1_dec":"Frac Explained by Top Neuron", "fve_next_9_dec":"Frac Explained by Next 9 Neurons"}) 182 | # # %% 183 | # px.histogram(feature_df.query("~is_rare"), x="fve_top_1", log_y=True) 184 | # # %% 185 | # feature_df_baseline = pd.DataFrame() 186 | # rand_W_dec = torch.randn_like(encoder.W_dec.T) 187 | # feature_kurtosis = scipy.stats.kurtosis(to_numpy(rand_W_dec)) 188 | # feature_df_baseline["neuron_kurt"] = to_numpy(feature_kurtosis) 189 | # feature_df_baseline["is_rare"] = to_numpy(is_rare) 190 | # # %% 191 | # sorted_W_dec = rand_W_dec.abs().sort(dim=0).values 192 | # sorted_W_dec_sq = sorted_W_dec.pow(2) 193 | # sorted_W_dec_sq_sum = sorted_W_dec.pow(2).sum(0) 194 | # for k in [1, 2, 5, 10, 100]: 195 | # feature_df_baseline[f"fve_top_{k}"] = to_numpy(sorted_W_dec_sq[-k:, :].sum(0) / sorted_W_dec_sq_sum) 196 | # feature_df_baseline["fve_next_9"] = feature_df_baseline["fve_top_10"] - feature_df_baseline["fve_top_1"] 197 | # feature_df_baseline.sort_values("neuron_kurt", ascending=False) 198 | # # %% 199 | # px.histogram(feature_df_baseline, x="fve_top_1", histnorm="percent", cumulative=True, marginal="box").show() 200 | # px.histogram(feature_df_baseline, x="fve_top_5", histnorm="percent", cumulative=True, marginal="box").show() 201 | # px.histogram(feature_df_baseline, x="fve_top_10", 
histnorm="percent", cumulative=True, marginal="box").show() 202 | # px.histogram(feature_df_baseline, x="fve_top_100", marginal="box").show() 203 | # # %% 204 | # px.scatter(feature_df_baseline.query("~is_rare"), x="fve_top_1", y="fve_next_9", hover_name=feature_df_baseline.query("~is_rare").index, color_continuous_scale="Portland", marginal_x="histogram", marginal_y="histogram", title="") 205 | # # %% 206 | # temp_df = pd.concat([feature_df_baseline.query("~is_rare"), feature_df.query("~is_rare")]) 207 | # temp_df = temp_df.reset_index(drop=True) 208 | # temp_df["category"] = ["baseline"]*(len(temp_df)//2) + ["real"]*(len(temp_df)//2) 209 | # px.histogram(temp_df, "neuron_kurt", color="category", barmode="overlay", marginal="box", range_x=(-5, 50), nbins=5000, title="Neuron Kurtosis (real vs random baseline, clipped at 50)") 210 | 211 | # # %% 212 | # feature_df["enc_dec_sim"] = to_numpy((encoder.W_dec * encoder.W_enc.T).sum(-1) / encoder.W_enc.norm(dim=0) / encoder.W_dec.norm(dim=-1)) 213 | # px.histogram(feature_df.query("~is_rare"), x="enc_dec_sim", title="Encoder Decoder Cosine Sim (non-rare features)") 214 | # # %% 215 | # U, S, Vh = torch.linalg.svd((model.W_out[0])) 216 | # print(U.shape) 217 | # line(S) 218 | # # %% 219 | # W_enc_svd = encoder.W_enc.T @ U 220 | # W_enc_svd_null = W_enc_svd[:, 512:].pow(2).sum(-1) 221 | # W_enc_svd_all = W_enc_svd[:, :].pow(2).sum(-1) 222 | # feature_df["enc_null_frac"] = to_numpy(W_enc_svd_null / W_enc_svd_all) 223 | 224 | # W_dec_svd = encoder.W_dec @ U 225 | # W_dec_svd_null = W_dec_svd[:, 512:].pow(2).sum(-1) 226 | # W_dec_svd_all = W_dec_svd[:, :].pow(2).sum(-1) 227 | # feature_df["dec_null_frac"] = to_numpy(W_dec_svd_null / W_dec_svd_all) 228 | # # px.histogram(feature_df, x="enc_null_frac", color="is_rare", barmode="overlay").show() 229 | # # px.histogram(feature_df, x="dec_null_frac", color="is_rare", barmode="overlay").show() 230 | # # px.scatter(feature_df, x="dec_null_frac", y="enc_null_frac", color="is_rare").show() 231 | # fig = px.histogram(feature_df.query("~is_rare"), x=["enc_null_frac", "dec_null_frac"], barmode="overlay", marginal="box", title="Fraction of feature in W_out null space (non-rare)") 232 | # fig.add_vline(x=0.75, line_dash="dash", line_color="gray") 233 | # # %% 234 | # fig = px.histogram(feature_df.query("~is_rare"), x=["enc_null_frac", "enc_null_frac_baseline"], barmode="overlay", marginal="box", title="Fraction of feature in W_out null space (non-rare)") 235 | # fig.add_vline(x=0.75, line_dash="dash", line_color="gray").show() 236 | # fig = px.histogram(feature_df.query("~is_rare"), x=["dec_null_frac", "dec_null_frac_baseline"], barmode="overlay", marginal="box", title="Fraction of feature in W_out null space (non-rare)") 237 | # fig.add_vline(x=0.75, line_dash="dash", line_color="gray") 238 | 239 | # %% 240 | # print(feature_df["enc_null_frac_baseline"].mean()) 241 | # print(feature_df["enc_null_frac_baseline"].std()) 242 | # print(feature_df["dec_null_frac_baseline"].mean()) 243 | # print(feature_df["dec_null_frac_baseline"].std()) 244 | 245 | # %% 246 | def basic_feature_vis(text, feature_index, max_val=0): 247 | feature_in = encoder.W_enc[:, feature_index] 248 | feature_bias = encoder.b_enc[feature_index] 249 | _, cache = model.run_with_cache(text, stop_at_layer=cfg["layer"]+1, names_filter=cfg["act_name"]) 250 | acts = cache[cfg["act_name"]][0] 251 | feature_acts = F.relu((acts - encoder.b_dec) @ feature_in + feature_bias) 252 | if max_val==0: 253 | max_val = max(1e-7, feature_acts.max().item()) 254 | # 
print(max_val) 255 | # if min_val==0: 256 | # min_val = min(-1e-7, feature_acts.min().item()) 257 | return basic_token_vis_make_str(text, feature_acts, max_val) 258 | def basic_token_vis_make_str(strings, values, max_val=None): 259 | if not isinstance(strings, list): 260 | strings = model.to_str_tokens(strings) 261 | values = to_numpy(values) 262 | if max_val is None: 263 | max_val = values.max() 264 | # if min_val is None: 265 | # min_val = values.min() 266 | header_string = f"Max Range {values.max():.4f} Min Range: {values.min():.4f}" 267 | header_string += f"Set Max Range {max_val:.4f}
" 268 | # values[values>0] = values[values>0]/ma|x_val 269 | # values[values<0] = values[values<0]/abs(min_val) 270 | body_string = nutils.create_html(strings, values, max_value=max_val, return_string=True) 271 | return header_string + body_string 272 | 273 | # The `with gr.Blocks() as demo:` syntax just creates a variable called demo containing all these components 274 | import gradio as gr 275 | try: 276 | demos[0].close() 277 | except: 278 | pass 279 | demos = [None] 280 | def make_feature_vis_gradio(batch, pos, feature_id): 281 | try: 282 | demos[0].close() 283 | except: 284 | pass 285 | with gr.Blocks() as demo: 286 | gr.HTML(value=f"Hacky Interactive Neuroscope for gelu-1l") 287 | # The input elements 288 | with gr.Row(): 289 | with gr.Column(): 290 | text = gr.Textbox(label="Text", value=model.to_string(tokens[batch, 1:pos+1])) 291 | # Precision=0 makes it an int, otherwise it's a float 292 | # Value sets the initial default value 293 | feature_index = gr.Number( 294 | label="Feature Index", value=feature_id, precision=0 295 | ) 296 | # # If empty, these two map to None 297 | max_val = gr.Number(label="Max Value", value=None) 298 | # min_val = gr.Number(label="Min Value", value=None) 299 | inputs = [text, feature_index, max_val] 300 | with gr.Row(): 301 | with gr.Column(): 302 | # The output element 303 | out = gr.HTML(label="Neuron Acts", value=basic_feature_vis(model.to_string(tokens[batch, 1:pos+1]), feature_id)) 304 | for inp in inputs: 305 | inp.change(basic_feature_vis, inputs, out) 306 | demo.launch(share=True) 307 | demos[0] = demo 308 | # %% 309 | vocab_df = pd.DataFrame({ 310 | "token": np.arange(d_vocab), 311 | "str_token": nutils.process_tokens(np.arange(d_vocab)), 312 | }) 313 | vocab_df["is_upper"] = vocab_df["str_token"].apply(lambda s: s!=s.lower() and s==s.upper()) 314 | vocab_df["is_word"] = vocab_df["str_token"].apply(lambda s: s.replace(nutils.SPACE, "").isalpha()) 315 | vocab_df["has_space"] = vocab_df["str_token"].apply(lambda s: len(s)>0 and s[0]==nutils.SPACE) 316 | vocab_df["is_capital"] = vocab_df["str_token"].apply(lambda s: len(s)>0 and ((s[0]==nutils.SPACE and s[1:]==s[1:].capitalize()) or (s[0]!=nutils.SPACE and s==s.capitalize()))) 317 | vocab_df 318 | # %% 319 | torch.set_grad_enabled(False) 320 | feature_U = (encoder.W_dec) @ model.W_U 321 | vocab_kurts = scipy.stats.kurtosis(to_numpy(feature_U.T)) 322 | feature_df["vocab_kurt"] = vocab_kurts 323 | # %% 324 | tokens = all_tokens[:1024] 325 | _, cache = model.run_with_cache(tokens, stop_at_layer=cfg["layer"]+1, names_filter=cfg["act_name"]) 326 | del _ 327 | acts = cache[cfg["act_name"]] 328 | acts_flattened = acts.reshape(-1, cfg["act_size"]) 329 | hidden_acts = F.relu((acts_flattened - encoder.b_dec) @ encoder.W_enc + encoder.b_enc) 330 | # %% 331 | vocab_has_space = torch.tensor(vocab_df.has_space.values).cuda() 332 | token_has_space = vocab_has_space[tokens.flatten()] 333 | # %% 334 | f_id = 2 335 | print(feature_df.loc[f_id]) 336 | 337 | token_df = nutils.make_token_df(tokens, 8) 338 | token_df["act"] = to_numpy(hidden_acts[:, f_id]) 339 | token_df["active"] = to_numpy(hidden_acts[:, f_id]>0) 340 | token_df = token_df.sort_values("act", ascending=False) 341 | nutils.show_df(token_df.head(50)) 342 | 343 | nutils.show_df(nutils.create_vocab_df(feature_U[f_id]).head(20)) 344 | nutils.show_df(nutils.create_vocab_df(feature_U[f_id]).tail(10)) 345 | px.histogram(to_numpy(feature_U[f_id]), color=vocab_df.has_space, barmode="overlay", histnorm="percent", marginal="box", labels={"title": "wDLA with(out) 
space", "color":"has_space"}, title="wDLA with(out) space").show() 346 | print("Frac fires with space", (token_has_space[hidden_acts[:, f_id]>0].sum() / len(token_has_space[hidden_acts[:, f_id]>0])).item()) 347 | px.histogram(to_numpy(hidden_acts[:, f_id][hidden_acts[:, f_id]>0]), color=to_numpy(token_has_space[hidden_acts[:, f_id]>0]), barmode="overlay", marginal="box", title="Acts with(out) space").show() 348 | i = 0 349 | make_feature_vis_gradio(token_df.batch.iloc[i], token_df.pos.iloc[i], f_id) 350 | # %% 351 | # Feature 1 352 | 353 | ends_with_ed = vocab_df.str_token.apply(lambda s: s.endswith("ed")).values 354 | px.histogram(to_numpy(feature_U[f_id]), color=ends_with_ed, barmode="overlay", histnorm="percent", marginal="box", labels={"title": "wDLA with(out) space", "color":"has_ed"}, title="wDLA with(out) ed", hover_name=vocab_df.str_token).show() 355 | # %% 356 | # Feature 2 357 | 358 | # %% 359 | 360 | 361 | # hidden = hidden_acts[:, f_id].reshape(tokens.shape) 362 | # ave_firing = (hidden>0).float().mean(-1) 363 | # ave_act = (hidden).mean(-1) 364 | # big_fire_thresh = 0.2 * token_df.act.max() 365 | # ave_act_cond = (hidden).sum(-1) / ((hidden>0).float().sum(-1)+1e-7) 366 | # line([ave_firing, ave_act, ave_act_cond], line_labels=["Freq firing", "Ave act", "Ave act if firing"], title="Per batch summary statistics") 367 | 368 | # argmax_token = tokens.flatten()[hidden.flatten().argmax(-1).cpu()] 369 | # argmax_str_token = model.to_string(argmax_token) 370 | # print(argmax_token, argmax_str_token) 371 | # pos_token_df = token_df[token_df.act>0] 372 | # frac_of_fires_are_top_token = (pos_token_df.str_tokens==argmax_str_token).sum()/len(pos_token_df) 373 | # frac_big_firing_on_top_token = (pos_token_df.query(f"act>{big_fire_thresh}").str_tokens==argmax_str_token).sum()/len(pos_token_df.query(f"act>{big_fire_thresh}")) 374 | # frac_of_top_token_are_fires = (hidden.flatten().cpu()[tokens.flatten()==argmax_token]>0).float().mean().item() 375 | # print(f"{frac_of_fires_are_top_token=:.2%}") 376 | # print(f"{frac_big_firing_on_top_token=:.2%}") 377 | # print(f"{frac_of_top_token_are_fires=:.2%}") 378 | # print(f"Sample size = {(tokens.flatten()==argmax_token).sum().item()}") 379 | 380 | # line([encoder.W_enc[:, f_id], encoder.W_dec[f_id, :]], xaxis="Neuron", title="Weights in the neuron basis", line_labels=["encoder", "decoder"]) 381 | 382 | # %% 383 | 384 | # %% 385 | temp_df = copy.deepcopy(vocab_df) 386 | def f(row): 387 | # print(row) 388 | if row.is_capital and row.has_space: 389 | return "Capital" 390 | elif not row.has_space and row.is_word: 391 | return "Fragment" 392 | else: 393 | return "Other" 394 | temp_df["cond"] = temp_df.apply(f, axis=1) 395 | temp_df["x"] = to_numpy(feature_U[f_id]) 396 | px.histogram(temp_df, x="x", color="is_capital", barmode="overlay", marginal="box", histnorm="percent", hover_name="str_token").show() 397 | px.histogram(temp_df, x="x", color="cond", barmode="overlay", marginal="box", histnorm="percent", hover_name="str_token").show() 398 | # %% 399 | -------------------------------------------------------------------------------- /rare_freq_dir.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neelnanda-io/1L-Sparse-Autoencoder/bcae01328a2f41d24bd4a9160828f2fc22737f75/rare_freq_dir.pt -------------------------------------------------------------------------------- /scratch.py: -------------------------------------------------------------------------------- 1 | # %% 2 | import os 3 | 
os.environ["TRANSFORMERS_CACHE"] = "/workspace/cache/" 4 | os.environ["DATASETS_CACHE"] = "/workspace/cache/" 5 | # %% 6 | from neel.imports import * 7 | from neel_plotly import * 8 | import wandb 9 | # %% 10 | import argparse 11 | def arg_parse_update_cfg(default_cfg): 12 | """ 13 | Helper function to take in a dictionary of arguments, convert these to command line arguments, look at what was passed in, and return an updated dictionary. 14 | 15 | If in Ipython, just returns with no changes 16 | """ 17 | if get_ipython() is not None: 18 | # Is in IPython 19 | print("In IPython - skipped argparse") 20 | return default_cfg 21 | cfg = dict(default_cfg) 22 | parser = argparse.ArgumentParser() 23 | for key, value in default_cfg.items(): 24 | if type(value) == bool: 25 | # argparse for Booleans is broken rip. Now you put in a flag to change the default --{flag} to set True, --{flag} to set False 26 | if value: 27 | parser.add_argument(f"--{key}", action="store_false") 28 | else: 29 | parser.add_argument(f"--{key}", action="store_true") 30 | 31 | else: 32 | parser.add_argument(f"--{key}", type=type(value), default=value) 33 | args = parser.parse_args() 34 | parsed_args = vars(args) 35 | cfg.update(parsed_args) 36 | print("Updated config") 37 | print(json.dumps(cfg, indent=2)) 38 | return cfg 39 | default_cfg = { 40 | "seed": 49, 41 | "batch_size": 4096, 42 | "buffer_mult": 384, 43 | "lr": 1e-4, 44 | "num_tokens": int(2e9), 45 | "l1_coeff": 3e-4, 46 | "beta1": 0.9, 47 | "beta2": 0.99, 48 | "dict_mult": 8, 49 | "seq_len": 128, 50 | "d_mlp": 2048, 51 | "enc_dtype":"fp32", 52 | "remove_rare_dir": True 53 | } 54 | cfg = arg_parse_update_cfg(default_cfg) 55 | cfg["model_batch_size"] = cfg["batch_size"] // cfg["seq_len"] * 16 56 | cfg["buffer_size"] = cfg["batch_size"] * cfg["buffer_mult"] 57 | cfg["buffer_batches"] = cfg["buffer_size"] // cfg["seq_len"] 58 | pprint.pprint(cfg) 59 | # %% 60 | 61 | SEED = cfg["seed"] 62 | GENERATOR = torch.manual_seed(SEED) 63 | DTYPES = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16} 64 | np.random.seed(SEED) 65 | random.seed(SEED) 66 | torch.set_grad_enabled(True) 67 | 68 | model = HookedTransformer.from_pretrained("gelu-1l").to(DTYPES[cfg["enc_dtype"]]) 69 | 70 | n_layers = model.cfg.n_layers 71 | d_model = model.cfg.d_model 72 | n_heads = model.cfg.n_heads 73 | d_head = model.cfg.d_head 74 | d_mlp = model.cfg.d_mlp 75 | d_vocab = model.cfg.d_vocab 76 | # %% 77 | @torch.no_grad() 78 | def get_mlp_acts(tokens, batch_size=1024): 79 | _, cache = model.run_with_cache(tokens, stop_at_layer=1, names_filter=utils.get_act_name("post", 0)) 80 | mlp_acts = cache[utils.get_act_name("post", 0)] 81 | mlp_acts = mlp_acts.reshape(-1, d_mlp) 82 | subsample = torch.randperm(mlp_acts.shape[0], generator=GENERATOR)[:batch_size] 83 | subsampled_mlp_acts = mlp_acts[subsample, :] 84 | return subsampled_mlp_acts, mlp_acts 85 | # sub, acts = get_mlp_acts(torch.arange(20).reshape(2, 10), batch_size=3) 86 | # sub.shape, acts.shape 87 | # %% 88 | SAVE_DIR = Path("/workspace/1L-Sparse-Autoencoder/checkpoints") 89 | class AutoEncoder(nn.Module): 90 | def __init__(self, cfg): 91 | super().__init__() 92 | d_hidden = cfg["d_mlp"] * cfg["dict_mult"] 93 | l1_coeff = cfg["l1_coeff"] 94 | dtype = DTYPES[cfg["enc_dtype"]] 95 | torch.manual_seed(cfg["seed"]) 96 | self.W_enc = nn.Parameter(torch.nn.init.kaiming_uniform_(torch.empty(d_mlp, d_hidden, dtype=dtype))) 97 | self.W_dec = nn.Parameter(torch.nn.init.kaiming_uniform_(torch.empty(d_hidden, d_mlp, dtype=dtype))) 98 | self.b_enc 
= nn.Parameter(torch.zeros(d_hidden, dtype=dtype)) 99 | self.b_dec = nn.Parameter(torch.zeros(d_mlp, dtype=dtype)) 100 | 101 | self.W_dec.data[:] = self.W_dec / self.W_dec.norm(dim=-1, keepdim=True) 102 | 103 | self.d_hidden = d_hidden 104 | self.l1_coeff = l1_coeff 105 | 106 | self.to("cuda") 107 | 108 | def forward(self, x): 109 | x_cent = x - self.b_dec 110 | acts = F.relu(x_cent @ self.W_enc + self.b_enc) 111 | x_reconstruct = acts @ self.W_dec + self.b_dec 112 | l2_loss = (x_reconstruct.float() - x.float()).pow(2).sum(-1).mean(0) 113 | l1_loss = self.l1_coeff * (acts.float().abs().sum()) 114 | loss = l2_loss + l1_loss 115 | return loss, x_reconstruct, acts, l2_loss, l1_loss 116 | 117 | @torch.no_grad() 118 | def remove_parallel_component_of_grads(self): 119 | W_dec_normed = self.W_dec / self.W_dec.norm(dim=-1, keepdim=True) 120 | W_dec_grad_proj = (self.W_dec.grad * W_dec_normed).sum(-1, keepdim=True) * W_dec_normed 121 | self.W_dec.grad -= W_dec_grad_proj 122 | 123 | def get_version(self): 124 | return 1+max([int(file.name.split(".")[0]) for file in list(SAVE_DIR.iterdir()) if "pt" in str(file)]) 125 | 126 | def save(self): 127 | version = self.get_version() 128 | torch.save(self.state_dict(), SAVE_DIR/(str(version)+".pt")) 129 | with open(SAVE_DIR/(str(version)+"_cfg.json"), "w") as f: 130 | json.dump(cfg, f) 131 | print("Saved as version", version) 132 | 133 | @classmethod 134 | def load(cls, version): 135 | cfg = (json.load(open(SAVE_DIR/(str(version)+"_cfg.json"), "r"))) 136 | pprint.pprint(cfg) 137 | self = cls(cfg=cfg) 138 | self.load_state_dict(torch.load(SAVE_DIR/(str(version)+".pt"))) 139 | return self 140 | # %% 141 | 142 | 143 | 144 | # %% 145 | 146 | 147 | # l1_coeff = 0.01 148 | # encoder = AutoEncoder(d_mlp*cfg["dict_mult"], l1_coeff=cfg['l1_coeff']).cuda() 149 | # loss, x_reconstruct, acts, l2_loss, l1_loss = encoder(sub) 150 | # print(loss, l2_loss, l1_loss) 151 | 152 | # loss.backward() 153 | # print(encoder.W_dec.grad.norm()) 154 | # encoder.remove_parallel_component_of_grads() 155 | # print(encoder.W_dec.grad.norm()) 156 | # print((encoder.W_dec.grad * encoder.W_dec).sum(-1)) 157 | # %% 158 | # wandb.init(project="autoencoder", entity="neelnanda-io") 159 | # # %% 160 | # c4_urls = [f"https://huggingface.co/datasets/allenai/c4/resolve/main/en/c4-train.{i:0>5}-of-01024.json.gz" for i in range(901, 950)] 161 | 162 | # dataset = load_dataset("json", data_files=c4_urls, split="train") 163 | 164 | # dataset_name="c4" 165 | # dataset.save_to_disk(f"/workspace/data/{dataset_name}_text.hf") 166 | # # %% 167 | # print(dataset) 168 | 169 | # from transformer_lens.utils import tokenize_and_concatenate 170 | 171 | # tokenizer = model.tokenizer 172 | 173 | # tokens = tokenize_and_concatenate(dataset, tokenizer, streaming=False, num_proc=20, max_length=128) 174 | # tokens.save_to_disk(f"/workspace/data/{dataset_name}_tokens.hf") 175 | # %% 176 | def shuffle_data(all_tokens): 177 | print("Shuffled data") 178 | return all_tokens[torch.randperm(all_tokens.shape[0])] 179 | 180 | loading_data_first_time = False 181 | if loading_data_first_time: 182 | data = load_dataset("NeelNanda/c4-code-tokenized-2b", split="train") 183 | data.save_to_disk("/workspace/data/c4_code_tokenized_2b.hf") 184 | data.set_format(type="torch", columns=["tokens"]) 185 | all_tokens = data["tokens"] 186 | all_tokens.shape 187 | 188 | 189 | all_tokens_reshaped = einops.rearrange(all_tokens, "batch (x seq_len) -> (batch x) seq_len", x=8, seq_len=128) 190 | all_tokens_reshaped[:, 0] = model.tokenizer.bos_token_id 
191 | all_tokens_reshaped = all_tokens_reshaped[torch.randperm(all_tokens_reshaped.shape[0])] 192 | torch.save(all_tokens_reshaped, "/workspace/data/c4_code_2b_tokens_reshaped.pt") 193 | else: 194 | # data = datasets.load_from_disk("/workspace/data/c4_code_tokenized_2b.hf") 195 | all_tokens = torch.load("/workspace/data/c4_code_2b_tokens_reshaped.pt") 196 | all_tokens = shuffle_data(all_tokens) 197 | 198 | # %% 199 | class Buffer(): 200 | def __init__(self, cfg): 201 | self.buffer = torch.zeros((cfg["buffer_size"], cfg["d_mlp"]), dtype=torch.bfloat16, requires_grad=False).cuda() 202 | self.cfg = cfg 203 | self.token_pointer = 0 204 | self.first = True 205 | self.refresh() 206 | 207 | @torch.no_grad() 208 | def refresh(self): 209 | self.pointer = 0 210 | with torch.autocast("cuda", torch.bfloat16): 211 | if self.first: 212 | num_batches = self.cfg["buffer_batches"] 213 | else: 214 | num_batches = self.cfg["buffer_batches"]//2 215 | self.first = False 216 | for _ in range(0, num_batches, self.cfg["model_batch_size"]): 217 | tokens = all_tokens[self.token_pointer:self.token_pointer+self.cfg["model_batch_size"]] 218 | _, cache = model.run_with_cache(tokens, stop_at_layer=1, names_filter=utils.get_act_name("post", 0)) 219 | mlp_acts = cache[utils.get_act_name("post", 0)].reshape(-1, self.cfg["d_mlp"]) 220 | # print(tokens.shape, mlp_acts.shape, self.pointer, self.token_pointer) 221 | self.buffer[self.pointer: self.pointer+mlp_acts.shape[0]] = mlp_acts 222 | self.pointer += mlp_acts.shape[0] 223 | self.token_pointer += self.cfg["model_batch_size"] 224 | # if self.token_pointer > all_tokens.shape[0] - self.cfg["model_batch_size"]: 225 | # self.token_pointer = 0 226 | 227 | self.pointer = 0 228 | self.buffer = self.buffer[torch.randperm(self.buffer.shape[0]).cuda()] 229 | 230 | @torch.no_grad() 231 | def next(self): 232 | out = self.buffer[self.pointer:self.pointer+self.cfg["batch_size"]] 233 | self.pointer += self.cfg["batch_size"] 234 | if self.pointer > self.buffer.shape[0]//2 - self.cfg["batch_size"]: 235 | # print("Refreshing the buffer!") 236 | self.refresh() 237 | return out 238 | 239 | # buffer.refresh() 240 | # %% 241 | 242 | # %% 243 | def replacement_hook(mlp_post, hook, encoder): 244 | mlp_post_reconstr = encoder(mlp_post)[1] 245 | return mlp_post_reconstr 246 | 247 | def mean_ablate_hook(mlp_post, hook): 248 | mlp_post[:] = mlp_post.mean([0, 1]) 249 | return mlp_post 250 | 251 | def zero_ablate_hook(mlp_post, hook): 252 | mlp_post[:] = 0. 
253 | return mlp_post 254 | 255 | @torch.no_grad() 256 | def get_recons_loss(num_batches=5, local_encoder=None): 257 | if local_encoder is None: 258 | local_encoder = encoder 259 | loss_list = [] 260 | for i in range(num_batches): 261 | tokens = all_tokens[torch.randperm(len(all_tokens))[:cfg["model_batch_size"]]] 262 | loss = model(tokens, return_type="loss") 263 | recons_loss = model.run_with_hooks(tokens, return_type="loss", fwd_hooks=[(utils.get_act_name("post", 0), partial(replacement_hook, encoder=local_encoder))]) 264 | # mean_abl_loss = model.run_with_hooks(tokens, return_type="loss", fwd_hooks=[(utils.get_act_name("post", 0), mean_ablate_hook)]) 265 | zero_abl_loss = model.run_with_hooks(tokens, return_type="loss", fwd_hooks=[(utils.get_act_name("post", 0), zero_ablate_hook)]) 266 | loss_list.append((loss, recons_loss, zero_abl_loss)) 267 | losses = torch.tensor(loss_list) 268 | loss, recons_loss, zero_abl_loss = losses.mean(0).tolist() 269 | 270 | print(loss, recons_loss, zero_abl_loss) 271 | score = ((zero_abl_loss - recons_loss)/(zero_abl_loss - loss)) 272 | print(f"{score:.2%}") 273 | # print(f"{((zero_abl_loss - mean_abl_loss)/(zero_abl_loss - loss)).item():.2%}") 274 | return score, loss, recons_loss, zero_abl_loss 275 | # print(get_recons_loss()) 276 | 277 | # %% 278 | # Frequency 279 | @torch.no_grad() 280 | def get_freqs(num_batches=25, local_encoder=None): 281 | if local_encoder is None: 282 | local_encoder = encoder 283 | act_freq_scores = torch.zeros(local_encoder.d_hidden, dtype=torch.float32).cuda() 284 | total = 0 285 | for i in tqdm.trange(num_batches): 286 | tokens = all_tokens[torch.randperm(len(all_tokens))[:cfg["model_batch_size"]]] 287 | 288 | _, cache = model.run_with_cache(tokens, stop_at_layer=1, names_filter=utils.get_act_name("post", 0)) 289 | mlp_acts = cache[utils.get_act_name("post", 0)] 290 | mlp_acts = mlp_acts.reshape(-1, d_mlp) 291 | 292 | hidden = local_encoder(mlp_acts)[2] 293 | 294 | act_freq_scores += (hidden > 0).sum(0) 295 | total+=hidden.shape[0] 296 | act_freq_scores /= total 297 | num_dead = (act_freq_scores==0).float().mean() 298 | print("Num dead", num_dead) 299 | return act_freq_scores 300 | # %% 301 | @torch.no_grad() 302 | def re_init(indices, encoder): 303 | new_W_enc = (torch.nn.init.kaiming_uniform_(torch.zeros_like(encoder.W_enc))) 304 | new_W_dec = (torch.nn.init.kaiming_uniform_(torch.zeros_like(encoder.W_dec))) 305 | new_b_enc = (torch.zeros_like(encoder.b_enc)) 306 | print(new_W_dec.shape, new_W_enc.shape, new_b_enc.shape) 307 | encoder.W_enc.data[:, indices] = new_W_enc[:, indices] 308 | encoder.W_dec.data[indices, :] = new_W_dec[indices, :] 309 | encoder.b_enc.data[indices] = new_b_enc[indices] 310 | # %% 311 | encoder = AutoEncoder(cfg) 312 | buffer = Buffer(cfg) 313 | evil_dir = torch.load("/workspace/1L-Sparse-Autoencoder/evil_dir.pt") 314 | evil_dir.requires_grad = False 315 | # %% 316 | try: 317 | wandb.init(project="autoencoder", entity="neelnanda-io") 318 | num_batches = cfg["num_tokens"] // cfg["batch_size"] 319 | # model_num_batches = cfg["model_batch_size"] * num_batches 320 | encoder_optim = torch.optim.Adam(encoder.parameters(), lr=cfg["lr"], betas=(cfg["beta1"], cfg["beta2"])) 321 | recons_scores = [] 322 | act_freq_scores_list = [] 323 | for i in tqdm.trange(num_batches): 324 | i = i % all_tokens.shape[0] 325 | # tokens = all_tokens[i:i+cfg["model_batch_size"]] 326 | # acts = get_mlp_acts(tokens, batch_size=cfg["batch_size"])[0].detach() 327 | acts = buffer.next() 328 | loss, x_reconstruct, mid_acts, l2_loss, 
l1_loss = encoder(acts) 329 | loss.backward() 330 | encoder.remove_parallel_component_of_grads() 331 | if cfg["remove_rare_dir"]: 332 | with torch.no_grad(): 333 | encoder.W_enc.grad -= (evil_dir @ encoder.W_enc.grad)[None, :] * evil_dir[:, None] 334 | encoder_optim.step() 335 | encoder_optim.zero_grad() 336 | if cfg["remove_rare_dir"]: 337 | with torch.no_grad(): 338 | encoder.W_enc -= (evil_dir @ encoder.W_enc)[None, :] * evil_dir[:, None] 339 | loss_dict = {"loss": loss.item(), "l2_loss": l2_loss.item(), "l1_loss": l1_loss.item()} 340 | del loss, x_reconstruct, mid_acts, l2_loss, l1_loss, acts 341 | if (i) % 100 == 0: 342 | wandb.log(loss_dict) 343 | print(loss_dict) 344 | if (i) % 1000 == 0: 345 | x = (get_recons_loss()) 346 | print("Reconstruction:", x) 347 | recons_scores.append(x[0]) 348 | freqs = get_freqs(5) 349 | act_freq_scores_list.append(freqs) 350 | # histogram(freqs.log10(), marginal="box", histnorm="percent", title="Frequencies") 351 | wandb.log({ 352 | "recons_score": x[0], 353 | "dead": (freqs==0).float().mean().item(), 354 | "below_1e-6": (freqs<1e-6).float().mean().item(), 355 | "below_1e-5": (freqs<1e-5).float().mean().item(), 356 | }) 357 | if (i+1) % 30000 == 0: 358 | encoder.save() 359 | wandb.log({"reset_neurons": 0.0}) 360 | freqs = get_freqs(50) 361 | to_be_reset = (freqs<10**(-5.5)) 362 | print("Resetting neurons!", to_be_reset.sum()) 363 | re_init(to_be_reset, encoder) 364 | finally: 365 | encoder.save() 366 | 367 | # %% 368 | 369 | # # %% 370 | # acts = get_mlp_acts(all_tokens[i:i+cfg["model_batch_size"]], batch_size=cfg["batch_size"])[0].detach() 371 | # loss, x_reconstruct, mid_acts, l2_loss, l1_loss = encoder(acts) 372 | # # %% 373 | # acts.shape, x_reconstruct.shape 374 | # acts.norm(dim=-1).mean(), x_reconstruct.norm(dim=-1).mean(), (acts - x_reconstruct).norm(dim=-1).mean() 375 | # # %% 376 | # line(encoder.W_enc[:, :20].T) 377 | # line(encoder.W_dec[:20]) 378 | # # %% 379 | # line(mid_acts.mean(0)) 380 | # # %% 381 | freqs = get_freqs(50) 382 | histogram(freqs.log10(), marginal="box", histnorm="percent", title="Frequencies") 383 | freqs_5 = get_freqs(5) 384 | scatter(x=freqs_5.log10(), y=freqs.log10()) 385 | # %% 386 | 387 | # %% 388 | (freqs<10**(-5.5)).float().mean() 389 | # %% 390 | @torch.no_grad() 391 | def re_init(indices, encoder): 392 | new_W_enc = (torch.nn.init.kaiming_uniform_(torch.zeros_like(encoder.W_enc))) 393 | new_W_dec = (torch.nn.init.kaiming_uniform_(torch.zeros_like(encoder.W_dec))) 394 | new_b_enc = (torch.zeros_like(encoder.b_enc)) 395 | print(new_W_dec.shape, new_W_enc.shape, new_b_enc.shape) 396 | encoder.W_enc.data[:, indices] = new_W_enc[:, indices] 397 | encoder.W_dec.data[indices, :] = new_W_dec[indices, :] 398 | encoder.b_enc.data[indices] = new_b_enc[indices] 399 | freqs = get_freqs(50) 400 | to_be_reset = (freqs<10**(-5.5)) 401 | re_init(to_be_reset, encoder) 402 | # %% 403 | x = (get_recons_loss()) 404 | print("Reconstruction:", x) 405 | # recons_scores.append(x[0]) 406 | freqs = get_freqs(5) 407 | # act_freq_scores_list.append(freqs) 408 | histogram((freqs+10**(-6.5)).log10(), marginal="box", histnorm="percent", title="Frequencies") 409 | 410 | # %% 411 | 412 | # %% 413 | enc2 = AutoEncoder.load(5) 414 | tokens = all_tokens[:32] 415 | acts = get_mlp_acts(tokens, batch_size=1)[1].detach() 416 | # acts = buffer.next() 417 | loss, x_reconstruct, mid_acts, l2_loss, l1_loss = enc2(acts) 418 | print(loss, l2_loss, l1_loss) 419 | # %% 420 | freqs = (mid_acts>0).float().mean(0) 421 | feature_df = pd.DataFrame({"freqs": 
to_numpy(freqs), "log_freq":to_numpy((freqs).log10())}) 422 | feature_df[feature_df["log_freq"]>-5] 423 | # %% 424 | f_id = 18 425 | token_df = nutils.make_token_df(tokens) 426 | token_df["act"] = to_numpy(mid_acts[:, f_id]) 427 | token_df["active"] = to_numpy(mid_acts[:, f_id]>0) 428 | nutils.show_df(token_df.sort_values("act", ascending=False).head(100)) 429 | 430 | # %% 431 | line(token_df.groupby("batch")["act"].mean()) 432 | line(token_df.groupby("batch")["active"].mean()) 433 | # %%1 434 | act_freq_scores_list = [] 435 | encoders = {} 436 | checkpoints = [25, 22, 21, 18, 15, 12, 9] 437 | for i in checkpoints: 438 | print(i) 439 | encoders[i] = AutoEncoder.load(i) 440 | freqs = get_freqs(20, encoders[i]) 441 | act_freq_scores_list.append(freqs) 442 | histogram((freqs+10**(-6.5)).log10(), marginal="box", histnorm="percent", title=f"Frequencies for checkpoint {i}") 443 | # %% 444 | def num_tokens_per_checkpoint(c): 445 | if c==25: 446 | # return "2000M" 447 | return 2000 448 | else: 449 | return int(((c - 8) * 30000 * 4096)/1e6) 450 | line(x=list(range(26)), y=[num_tokens_per_checkpoint(c) for c in range(26)], title="Number of tokens per checkpoint") 451 | freqs = torch.stack(act_freq_scores_list).flatten() 452 | temp_df = pd.DataFrame({"freqs": to_numpy(freqs), "log_freq":to_numpy((freqs+10**(-6.5)).log10()), 453 | "checkpoint": [c for c in checkpoints for _ in range(encoders[25].d_hidden)], 454 | "million_tokens": [num_tokens_per_checkpoint(c) for c in checkpoints for _ in range(encoders[25].d_hidden)], 455 | }) 456 | px.histogram(temp_df, color="million_tokens", x="log_freq", barmode="overlay", marginal="box", histnorm="percent", title="Frequencies for checkpoints") 457 | # %% 458 | scatter(x=temp_df.query("checkpoint==21").log_freq, y=temp_df.query("checkpoint==22").log_freq, marginal_x="box", marginal_y="box", title="Frequencies for checkpoints 21 and 22", include_diag=True, xaxis=f"{num_tokens_per_checkpoint(21)}M", yaxis=f"{num_tokens_per_checkpoint(22)}M") 459 | scatter(x=temp_df.query("checkpoint==21").log_freq, y=temp_df.query("checkpoint==25").log_freq, marginal_x="box", marginal_y="box", title="Frequencies for checkpoints 21 and 25", include_diag=True, xaxis=f"{num_tokens_per_checkpoint(21)}M", yaxis=f"{num_tokens_per_checkpoint(25)}M") 460 | scatter(x=temp_df.query("checkpoint==22").log_freq, y=temp_df.query("checkpoint==25").log_freq, marginal_x="box", marginal_y="box", title="Frequencies for checkpoints 22 and 25", include_diag=True, xaxis=f"{num_tokens_per_checkpoint(22)}M", yaxis=f"{num_tokens_per_checkpoint(25)}M") 461 | # %% 462 | torch.set_grad_enabled(False) 463 | # %% 464 | tokens = all_tokens[:256] 465 | _, cache = model.run_with_cache(tokens, stop_at_layer=1, names_filter=utils.get_act_name("post", 0)) 466 | mlp_acts = cache[utils.get_act_name("post", 0)] 467 | mlp_acts_flattened = mlp_acts.reshape(-1, cfg["d_mlp"]) 468 | encoder = AutoEncoder.load(25) 469 | hidden_acts = F.relu((mlp_acts_flattened - encoder.b_dec) @ encoder.W_enc + encoder.b_enc) 470 | mlp_reconstr = hidden_acts @ encoder.W_dec + encoder.b_dec 471 | l2_loss = (mlp_acts_flattened - mlp_reconstr).pow(2).sum(-1).mean(0) 472 | l1_loss = encoder.l1_coeff * (hidden_acts.abs().sum()) 473 | print(l2_loss, l1_loss) 474 | # %% 475 | freqs = get_freqs(25, encoder) 476 | # %% 477 | histogram((freqs+10**-6.5).log10(), histnorm="percent", title="Frequencies for Final Checkpoint", xaxis="Freq (Log10)", yaxis="Percent") 478 | 479 | # %% 480 | is_rare = freqs < 1e-4 481 | 482 | 483 | # %% 484 | 485 | 486 | # %% 
487 | 488 | def replace_mlp_post(mlp_post, hook, replacement): 489 | mlp_post[:] = replacement 490 | return mlp_post 491 | recons_loss = model.run_with_hooks(tokens, return_type="loss", fwd_hooks=[(utils.get_act_name("post", 0), partial(replace_mlp_post, replacement=mlp_reconstr.reshape(mlp_acts.shape)))]) 492 | zero_abl_loss = model.run_with_hooks(tokens, return_type="loss", fwd_hooks=[(utils.get_act_name("post", 0), partial(replace_mlp_post, replacement=torch.zeros_like(mlp_acts)))]) 493 | normal_loss = model(tokens, return_type="loss") 494 | mean_loss = model.run_with_hooks(tokens, return_type="loss", fwd_hooks=[(utils.get_act_name("post", 0), partial(replace_mlp_post, replacement=mlp_acts.mean(0, keepdim=True).mean(1, keepdim=True)))]) 495 | print(f"{recons_loss.item()=}") 496 | print(f"{zero_abl_loss.item()=}") 497 | print(f"{normal_loss.item()=}") 498 | print(f"{mean_loss.item()=}") 499 | # %% 500 | new_losses = [] 501 | min_freq_list = [] 502 | for thresh in [-6.5, -5, -4.5, -4.4, -4.3, -4.2, -4.1, -4, -3, -2.5, -2, -1.5, -1, 0]: 503 | indices = freqs >= 10**thresh 504 | replacement = hidden_acts[:, indices] @ encoder.W_dec[indices, :] + encoder.b_dec 505 | new_loss = model.run_with_hooks(tokens, return_type="loss", fwd_hooks=[(utils.get_act_name("post", 0), partial(replace_mlp_post, replacement=replacement.reshape(mlp_acts.shape)))]) 506 | new_losses.append(new_loss.item()) 507 | min_freq_list.append(thresh) 508 | new_losses = np.array(new_losses) 509 | line(x=min_freq_list, y=new_losses, title="Loss vs minimum frequency") 510 | line(x=min_freq_list, y=(zero_abl_loss.item() - new_losses)/(zero_abl_loss.item() - normal_loss.item()), title="Scaled Loss vs minimum frequency final checkpoint", yaxis="% Loss Recovered", xaxis="Log freq floor") 511 | # %% 512 | # encoder2 = AutoEncoder.load(21) 513 | # hidden_acts = F.relu((mlp_acts_flattened - encoder2.b_dec) @ encoder2.W_enc + encoder2.b_enc) 514 | # mlp_reconstr = hidden_acts @ encoder2.W_dec + encoder2.b_dec 515 | # l2_loss = (mlp_acts_flattened - mlp_reconstr).pow(2).sum(-1).mean(0) 516 | # l1_loss = encoder2.l1_coeff * (hidden_acts.abs().sum()) 517 | # print(l2_loss, l1_loss) 518 | 519 | # freqs = get_freqs(25, encoder2) 520 | # histogram((freqs+10**-6.5).log10(), barmode="overlay", marginal="box", histnorm="percent", title="Frequencies for checkpoint 21") 521 | 522 | # %% 523 | new_losses2 = [] 524 | min_freq_list2 = [] 525 | for thresh in [-6.5, -6, -5.5, -5, -4, -3, -2.5, -2, -1.5, -1, 0]: 526 | indices = freqs >= 10**thresh 527 | replacement = hidden_acts[:, indices] @ encoder2.W_dec[indices, :] + encoder2.b_dec 528 | new_loss = model.run_with_hooks(tokens, return_type="loss", fwd_hooks=[(utils.get_act_name("post", 0), partial(replace_mlp_post, replacement=replacement.reshape(mlp_acts.shape)))]) 529 | new_losses2.append(new_loss.item()) 530 | min_freq_list2.append(thresh) 531 | new_losses2 = np.array(new_losses2) 532 | line(x=min_freq_list2, y=new_losses2, title="Loss vs minimum frequency") 533 | line(x=min_freq_list2, y=(zero_abl_loss.item() - new_losses2)/(zero_abl_loss.item() - normal_loss.item()), title="Scaled Loss vs minimum frequency checkpoint 21", yaxis="% Loss Recovered", xaxis="Log freq floor") 534 | 535 | # %% 536 | fig = line(x=min_freq_list, y=(zero_abl_loss.item() - new_losses)/(zero_abl_loss.item() - normal_loss.item()), title="Scaled Reconstructed Loss vs minimum frequency checkpoint", yaxis="% Loss Recovered", xaxis="Log freq floor", return_fig=True, line_labels=["Final checkpoint"]) 537 | 
fig.add_trace(go.Scatter(x=min_freq_list2, y=(zero_abl_loss.item() - new_losses2)/(zero_abl_loss.item() - normal_loss.item()), name="Checkpoint 21")) 538 | # %% 539 | def basic_feature_vis(text, feature_index, max_val=0): 540 | feature_in = encoder.W_enc[:, feature_index] 541 | feature_bias = encoder.b_enc[feature_index] 542 | _, cache = model.run_with_cache(text, stop_at_layer=1, names_filter=utils.get_act_name("post", 0)) 543 | mlp_acts = cache[utils.get_act_name("post", 0)][0] 544 | feature_acts = F.relu((mlp_acts - encoder.b_dec) @ feature_in + feature_bias) 545 | if max_val==0: 546 | max_val = max(1e-7, feature_acts.max().item()) 547 | # print(max_val) 548 | # if min_val==0: 549 | # min_val = min(-1e-7, feature_acts.min().item()) 550 | return basic_token_vis_make_str(text, feature_acts, max_val) 551 | def basic_token_vis_make_str(strings, values, max_val=None): 552 | if not isinstance(strings, list): 553 | strings = model.to_str_tokens(strings) 554 | values = to_numpy(values) 555 | if max_val is None: 556 | max_val = values.max() 557 | # if min_val is None: 558 | # min_val = values.min() 559 | header_string = f"Max Range {values.max():.4f} Min Range: {values.min():.4f}" 560 | header_string += f"Set Max Range {max_val:.4f}
" 561 | # values[values>0] = values[values>0]/ma|x_val 562 | # values[values<0] = values[values<0]/abs(min_val) 563 | body_string = nutils.create_html(strings, values, max_value=max_val, return_string=True) 564 | return header_string + body_string 565 | display(HTML(basic_token_vis_make_str(tokens[0, :10], mlp_acts[0, :10, 7], 0.1))) 566 | display(HTML(basic_feature_vis("I really like food food calories burgers eating is great", 7))) 567 | # %% 568 | # The `with gr.Blocks() as demo:` syntax just creates a variable called demo containing all these components 569 | import gradio as gr 570 | try: 571 | demos[0].close() 572 | except: 573 | pass 574 | demos = [None] 575 | def make_feature_vis_gradio(batch, pos, feature_id): 576 | try: 577 | demos[0].close() 578 | except: 579 | pass 580 | with gr.Blocks() as demo: 581 | gr.HTML(value=f"Hacky Interactive Neuroscope for gelu-1l") 582 | # The input elements 583 | with gr.Row(): 584 | with gr.Column(): 585 | text = gr.Textbox(label="Text", value=model.to_string(tokens[batch, 1:pos+1])) 586 | # Precision=0 makes it an int, otherwise it's a float 587 | # Value sets the initial default value 588 | feature_index = gr.Number( 589 | label="Feature Index", value=feature_id, precision=0 590 | ) 591 | # # If empty, these two map to None 592 | max_val = gr.Number(label="Max Value", value=None) 593 | # min_val = gr.Number(label="Min Value", value=None) 594 | inputs = [text, feature_index, max_val] 595 | with gr.Row(): 596 | with gr.Column(): 597 | # The output element 598 | out = gr.HTML(label="Neuron Acts", value=basic_feature_vis(model.to_string(tokens[batch, 1:pos+1]), feature_id)) 599 | for inp in inputs: 600 | inp.change(basic_feature_vis, inputs, out) 601 | demo.launch(share=True) 602 | demos[0] = demo 603 | # %% 604 | batch = 0 605 | feature_id = 7 606 | pos = 28 607 | make_feature_vis_gradio(batch, pos, feature_id) 608 | # %% 609 | 610 | # %% 611 | px.scatter(x=to_numpy(encoder.b_enc), y=to_numpy((freqs+10**-5).log10()), trendline="ols", labels={"x":"b_encoder", "y":"log10 freq", "color":"Is Rare"}, color=to_numpy(freqs<10**(-3.5)), title="Encoder bias vs frequency", marginal_x="histogram", marginal_y="histogram").show() 612 | px.scatter(x=to_numpy(encoder.W_enc.norm(dim=0)), y=to_numpy((freqs+10**-5).log10()), trendline="ols", labels={"x":"W_encoder.norm", "y":"log10 freq", "color":"Is Rare"}, color=to_numpy(freqs<10**(-3.5)), title="Encoder norm vs frequency", marginal_x="histogram", marginal_y="histogram").show() 613 | px.scatter(x=to_numpy(encoder.W_dec.norm(dim=-1)), y=to_numpy((freqs+10**-5).log10()), trendline="ols", labels={"x":"W_decoder.norm", "y":"log10 freq", "color":"Is Rare"}, color=to_numpy(freqs<10**(-3.5)), title="Decoder norm vs frequency", marginal_x="histogram", marginal_y="histogram").show() 614 | px.scatter(x=to_numpy(encoder.b_enc * encoder.W_dec.norm(dim=-1)), y=to_numpy((freqs+10**-5).log10()), trendline="ols", labels={"x":"b_encoder * W_dec.norm", "y":"log10 freq", "color":"Is Rare"}, color=to_numpy(freqs<10**(-3.5)), title="Weighted encoder bias vs frequency", marginal_x="histogram", marginal_y="histogram").show() 615 | px.scatter(x=to_numpy(encoder.W_enc.norm(dim=0) * encoder.W_dec.norm(dim=-1)), y=to_numpy((freqs+10**-5).log10()), trendline="ols", labels={"x":"W_enc.norm * W_dec.norm", "y":"log10 freq", "color":"Is Rare"}, color=to_numpy(freqs<10**(-3.5)), title="Encoder norm products vs frequency", marginal_x="histogram", marginal_y="histogram").show() 616 | # %% 617 | import huggingface_hub 618 | from pathlib import Path 
619 | def push_to_hub(local_dir): 620 | if isinstance(local_dir, huggingface_hub.Repository): 621 | local_dir = local_dir.local_dir 622 | os.system(f"git -C {local_dir} add .") 623 | os.system(f"git -C {local_dir} commit -m 'Auto Commit'") 624 | os.system(f"git -C {local_dir} push") 625 | 626 | 627 | # move_folder_to_hub("v235_4L512W_solu_wikipedia", "NeelNanda/SoLU_4L512W_Wiki_Finetune", just_final=False) 628 | # def move_folder_to_hub(model_name, repo_name=None, just_final=True, debug=False): 629 | # if repo_name is None: 630 | # repo_name = model_name 631 | # model_folder = CHECKPOINT_DIR / model_name 632 | # repo_folder = CHECKPOINT_DIR / (model_name + "_repo") 633 | # repo_url = huggingface_hub.create_repo(repo_name, exist_ok=True) 634 | # repo = huggingface_hub.Repository(repo_folder, repo_url) 635 | 636 | # for file in model_folder.iterdir(): 637 | # if not just_final or "final" in file.name or "config" in file.name: 638 | # if debug: 639 | # print(file.name) 640 | # file.rename(repo_folder / file.name) 641 | # push_to_hub(repo.local_dir) 642 | def upload_folder_to_hf(folder_path, repo_name=None, debug=False): 643 | folder_path = Path(folder_path) 644 | if repo_name is None: 645 | repo_name = folder_path.name 646 | repo_folder = folder_path.parent / (folder_path.name + "_repo") 647 | repo_url = huggingface_hub.create_repo(repo_name, exist_ok=True) 648 | repo = huggingface_hub.Repository(repo_folder, repo_url) 649 | 650 | for file in folder_path.iterdir(): 651 | if debug: 652 | print(file.name) 653 | file.rename(repo_folder / file.name) 654 | push_to_hub(repo.local_dir) 655 | upload_folder_to_hf("/workspace/1L-Sparse-Autoencoder/checkpoints_copy_2", "sparse_autoencoder", True) 656 | # %% 657 | freqs = (hidden_acts>0).float().mean(0) 658 | feature_df = pd.DataFrame({"freqs": to_numpy(freqs), "log_freq":to_numpy((freqs).log10())}) 659 | feature_df["is_common"] = feature_df["log_freq"]>-3.5 660 | neuron_kurts = scipy.stats.kurtosis(to_numpy(encoder.W_enc)) 661 | feature_U = (encoder.W_dec @ model.W_out[0]) @ model.W_U 662 | vocab_kurts = scipy.stats.kurtosis(to_numpy(feature_U.T)) 663 | feature_df["vocab_kurt"] = vocab_kurts 664 | feature_df["neuron_kurt"] = neuron_kurts 665 | neuron_frac_max = encoder.W_enc.max(dim=0).values / encoder.W_enc.abs().sum(0) 666 | feature_df["neuron_frac_max"] = to_numpy(neuron_frac_max) 667 | # %% 668 | encoder2 = AutoEncoder.load(47) 669 | freqs2 = get_freqs(5, encoder2) 670 | is_common2 = freqs2>10**-3.5 671 | is_common1 = freqs>10**-3.5 672 | cosine_sims = nutils.cos_mat(encoder.W_enc.T, encoder2.W_enc[:, is_common2].T) 673 | max_cosine_sim = cosine_sims.max(-1).values 674 | feature_df["max_cos"] = to_numpy(max_cosine_sim) 675 | feature_df 676 | # %% 677 | px.histogram(feature_df, x="neuron_kurt", marginal="box", color="is_common", histnorm="percent", title="Neuron Kurtosis", barmode="overlay", hover_name=feature_df.index).show() 678 | px.histogram(feature_df, x="neuron_frac_max", marginal="box", color="is_common", histnorm="percent", title="Neuron Frac Max", barmode="overlay", hover_name=feature_df.index).show() 679 | px.histogram(feature_df, x="vocab_kurt", marginal="box", color="is_common", histnorm="percent", title="Vocab Kurtosis", barmode="overlay", hover_name=feature_df.index).show() 680 | px.scatter(feature_df, x="neuron_kurt", y="neuron_frac_max", hover_name=feature_df.index).show() 681 | # %% 682 | top_features = feature_df.sort_values("neuron_kurt", ascending=False).head(20).index.tolist() 683 | line(encoder.W_enc[:, top_features].T, 
line_labels=top_features, title="Top Features in Neuron Basis", xaxis="Neuron") 684 | # # %% 685 | # f_id = 6 686 | # token_df = nutils.make_token_df(tokens, 8) 687 | # token_df["act"] = to_numpy(hidden_acts[:, f_id]) 688 | # token_df["active"] = to_numpy(hidden_acts[:, f_id]>0) 689 | # token_df = token_df.sort_values("act", ascending=False) 690 | # nutils.show_df(token_df.head(100)) 691 | 692 | # i = 0 693 | # make_feature_vis_gradio(token_df.batch.iloc[i], token_df.pos.iloc[i], f_id) 694 | # # %% 695 | # is_rare = ~feature_df.is_common.values 696 | # U, S, Vh = torch.linalg.svd(encoder.W_enc[:, ~is_rare]) 697 | # line(S, title="Singular Values of common features") 698 | # histogram(U[:, :5], barmode="overlay", title="MLP side singular vectors common") 699 | # histogram(Vh[:, :5], barmode="overlay", title="Feature side singular vectors common") 700 | 701 | # U, S, Vh = torch.linalg.svd(encoder.W_enc[:, is_rare]) 702 | # line(S, title="Singular Values of rare features") 703 | # histogram(U[:, :5], barmode="overlay", title="MLP side singular vectors rare") 704 | # histogram(Vh[:, :5], barmode="overlay", title="Feature side singular vectors rare") 705 | 706 | # token_df = nutils.make_token_df(tokens, 8, 3) 707 | # token_df["rare_ave_feature"] = to_numpy(mlp_acts_flattened @ U[:, 0] * S[0] * 0.01) 708 | # token_df["num_rare_active"] = to_numpy((hidden_acts[:, is_rare]>0).float().mean(-1)) 709 | # token_df["num_com_active"] = to_numpy((hidden_acts[:, ~is_rare]>0).float().mean(-1)) 710 | # token_df["mlp_act_norm"] = to_numpy(mlp_acts_flattened.norm(dim=-1)) 711 | # nutils.show_df(token_df.sort_values("rare_ave_feature", ascending=False).head(50)) 712 | # nutils.show_df(token_df.sort_values("rare_ave_feature", ascending=False).tail(50)) 713 | # svd_logit_lens = U[:, 0] @ model.W_out[0] @ model.W_U 714 | # nutils.show_df(nutils.create_vocab_df(svd_logit_lens).head(20)) 715 | # # %% 716 | # nutils.show_df(feature_df.query("is_common").head(20)) 717 | # %% 718 | torch.set_grad_enabled(False) 719 | f_id = 12 720 | print(feature_df.loc[f_id]) 721 | tokens = all_tokens[:512] 722 | 723 | _, cache = model.run_with_cache(tokens, stop_at_layer=1, names_filter=utils.get_act_name("post", 0)) 724 | mlp_acts = cache[utils.get_act_name("post", 0)] 725 | mlp_acts_flattened = mlp_acts.reshape(-1, cfg["d_mlp"]) 726 | hidden_acts = F.relu((mlp_acts_flattened - encoder.b_dec) @ encoder.W_enc + encoder.b_enc) 727 | 728 | token_df = nutils.make_token_df(tokens, 8) 729 | token_df["act"] = to_numpy(hidden_acts[:, f_id]) 730 | token_df["active"] = to_numpy(hidden_acts[:, f_id]>0) 731 | token_df = token_df.sort_values("act", ascending=False) 732 | nutils.show_df(token_df.head(50)) 733 | hidden = hidden_acts[:, f_id].reshape(tokens.shape) 734 | ave_firing = (hidden>0).float().mean(-1) 735 | ave_act = (hidden).mean(-1) 736 | big_fire_thresh = 0.2 * token_df.act.max() 737 | ave_act_cond = (hidden).sum(-1) / ((hidden>0).float().sum(-1)+1e-7) 738 | line([ave_firing, ave_act, ave_act_cond], line_labels=["Freq firing", "Ave act", "Ave act if firing"], title="Per batch summary statistics") 739 | 740 | argmax_token = tokens.flatten()[hidden.flatten().argmax(-1).cpu()] 741 | argmax_str_token = model.to_string(argmax_token) 742 | print(argmax_token, argmax_str_token) 743 | pos_token_df = token_df[token_df.act>0] 744 | frac_of_fires_are_top_token = (pos_token_df.str_tokens==argmax_str_token).sum()/len(pos_token_df) 745 | frac_big_firing_on_top_token = 
(pos_token_df.query(f"act>{big_fire_thresh}").str_tokens==argmax_str_token).sum()/len(pos_token_df.query(f"act>{big_fire_thresh}")) 746 | frac_of_top_token_are_fires = (hidden.flatten().cpu()[tokens.flatten()==argmax_token]>0).float().mean().item() 747 | print(f"{frac_of_fires_are_top_token=:.2%}") 748 | print(f"{frac_big_firing_on_top_token=:.2%}") 749 | print(f"{frac_of_top_token_are_fires=:.2%}") 750 | print(f"Sample size = {(tokens.flatten()==argmax_token).sum().item()}") 751 | 752 | line([encoder.W_enc[:, f_id], encoder.W_dec[f_id, :]], xaxis="Neuron", title="Weights in the neuron basis", line_labels=["encoder", "decoder"]) 753 | 754 | nutils.show_df(nutils.create_vocab_df(feature_U[f_id]).head(20)) 755 | nutils.show_df(nutils.create_vocab_df(feature_U[f_id]).tail(10)) 756 | i = 0 757 | make_feature_vis_gradio(token_df.batch.iloc[i], token_df.pos.iloc[i], f_id) 758 | # %% 759 | # %% 760 | strings = [ 761 | "and| they|", 762 | "and| you|", 763 | "and| we|", 764 | "and| it|", 765 | "and| I|", 766 | "and| she|", 767 | "but| they|", 768 | "but| you|", 769 | "but| we|", 770 | "but| it|", 771 | "but| I|", 772 | "but| she|", 773 | "or| they|", 774 | "or| you|", 775 | "or| we|", 776 | "or| it|", 777 | "or| I|", 778 | "or| she|", 779 | # "but| they|", 780 | # "but| you|", 781 | # "but| we|", 782 | # "but| it|", 783 | # "but| I|", 784 | # "but| she|", 785 | ] 786 | token_df["and_pronoun"] = [any(x in c for x in strings) for c in token_df.context] 787 | px.histogram(token_df, x="act", marginal="box", color="and_pronoun", histnorm="percent", title="Neuron activation for pronouns", barmode="overlay", hover_name="context").show() 788 | # %% 789 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | # %% 2 | from utils import * 3 | # %% 4 | encoder = AutoEncoder(cfg) 5 | buffer = Buffer(cfg) 6 | # Code used to remove the "rare freq direction", the shared direction among the ultra low frequency features. 7 | # I experimented with removing it and retraining the autoencoder. 
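# A minimal sketch (not necessarily what this run did) of one way to keep the
# encoder features orthogonal to that direction: after each optimizer step,
# project it out of every feature, assuming rare_freq_dir has shape (act_size,)
# and unit norm, and W_enc has shape (act_size, d_hidden):
#   coeffs = rare_freq_dir @ encoder.W_enc  # (d_hidden,) projection coefficients
#   encoder.W_enc.data -= rare_freq_dir[:, None] * coeffs[None, :]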
8 | if cfg["remove_rare_dir"]: 9 | rare_freq_dir = torch.load("rare_freq_dir.pt") 10 | rare_freq_dir.requires_grad = False 11 | 12 | # %% 13 | try: 14 | wandb.init(project="autoencoder", entity="neelnanda-io") 15 | num_batches = cfg["num_tokens"] // cfg["batch_size"] 16 | # model_num_batches = cfg["model_batch_size"] * num_batches 17 | encoder_optim = torch.optim.Adam(encoder.parameters(), lr=cfg["lr"], betas=(cfg["beta1"], cfg["beta2"])) 18 | recons_scores = [] 19 | act_freq_scores_list = [] 20 | for i in tqdm.trange(num_batches): 21 | i = i % all_tokens.shape[0] 22 | acts = buffer.next() 23 | loss, x_reconstruct, mid_acts, l2_loss, l1_loss = encoder(acts) 24 | loss.backward() 25 | encoder.make_decoder_weights_and_grad_unit_norm() 26 | encoder_optim.step() 27 | encoder_optim.zero_grad() 28 | loss_dict = {"loss": loss.item(), "l2_loss": l2_loss.item(), "l1_loss": l1_loss.item()} 29 | del loss, x_reconstruct, mid_acts, l2_loss, l1_loss, acts 30 | if (i) % 100 == 0: 31 | wandb.log(loss_dict) 32 | print(loss_dict) 33 | if (i) % 1000 == 0: 34 | x = (get_recons_loss(local_encoder=encoder)) 35 | print("Reconstruction:", x) 36 | recons_scores.append(x[0]) 37 | freqs = get_freqs(5, local_encoder=encoder) 38 | act_freq_scores_list.append(freqs) 39 | # histogram(freqs.log10(), marginal="box", histnorm="percent", title="Frequencies") 40 | wandb.log({ 41 | "recons_score": x[0], 42 | "dead": (freqs==0).float().mean().item(), 43 | "below_1e-6": (freqs<1e-6).float().mean().item(), 44 | "below_1e-5": (freqs<1e-5).float().mean().item(), 45 | }) 46 | if (i+1) % 30000 == 0: 47 | encoder.save() 48 | wandb.log({"reset_neurons": 0.0}) 49 | freqs = get_freqs(50, local_encoder=encoder) 50 | to_be_reset = (freqs<10**(-5.5)) 51 | print("Resetting neurons!", to_be_reset.sum()) 52 | re_init(to_be_reset, encoder) 53 | finally: 54 | encoder.save() 55 | # %% -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | # %% 2 | import os 3 | os.environ["TRANSFORMERS_CACHE"] = "/workspace/cache/" 4 | os.environ["DATASETS_CACHE"] = "/workspace/cache/" 5 | # %% 6 | from neel.imports import * 7 | from neel_plotly import * 8 | import wandb 9 | # %% 10 | import argparse 11 | def arg_parse_update_cfg(default_cfg): 12 | """ 13 | Helper function to take in a dictionary of arguments, convert these to command line arguments, look at what was passed in, and return an updated dictionary. 14 | 15 | If in Ipython, just returns with no changes 16 | """ 17 | if get_ipython() is not None: 18 | # Is in IPython 19 | print("In IPython - skipped argparse") 20 | return default_cfg 21 | cfg = dict(default_cfg) 22 | parser = argparse.ArgumentParser() 23 | for key, value in default_cfg.items(): 24 | if type(value) == bool: 25 | # argparse for Booleans is broken rip. 
Instead, each Boolean gets a single flag that inverts its default: passing --{key} sets it to False if the default is True, and to True if the default is False 26 | if value: 27 | parser.add_argument(f"--{key}", action="store_false") 28 | else: 29 | parser.add_argument(f"--{key}", action="store_true") 30 | 31 | else: 32 | parser.add_argument(f"--{key}", type=type(value), default=value) 33 | args = parser.parse_args() 34 | parsed_args = vars(args) 35 | cfg.update(parsed_args) 36 | print("Updated config") 37 | print(json.dumps(cfg, indent=2)) 38 | return cfg 39 | default_cfg = { 40 | "seed": 49, 41 | "batch_size": 4096, 42 | "buffer_mult": 384, 43 | "lr": 1e-4, 44 | "num_tokens": int(2e9), 45 | "l1_coeff": 3e-4, 46 | "beta1": 0.9, 47 | "beta2": 0.99, 48 | "dict_mult": 32, 49 | "seq_len": 128, 50 | "enc_dtype":"fp32", 51 | "remove_rare_dir": False, 52 | "model_name": "gelu-2l", 53 | "site": "mlp_out", 54 | "layer": 0, 55 | "device": "cuda:0" 56 | } 57 | site_to_size = { 58 | "mlp_out": 512, 59 | "post": 2048, 60 | "resid_pre": 512, 61 | "resid_mid": 512, 62 | "resid_post": 512, 63 | } 64 | 65 | cfg = arg_parse_update_cfg(default_cfg) 66 | def post_init_cfg(cfg): 67 | cfg["model_batch_size"] = cfg["batch_size"] // cfg["seq_len"] * 16 68 | cfg["buffer_size"] = cfg["batch_size"] * cfg["buffer_mult"] 69 | cfg["buffer_batches"] = cfg["buffer_size"] // cfg["seq_len"] 70 | cfg["act_name"] = utils.get_act_name(cfg["site"], cfg["layer"]) 71 | cfg["act_size"] = site_to_size[cfg["site"]] 72 | cfg["dict_size"] = cfg["act_size"] * cfg["dict_mult"] 73 | cfg["name"] = f"{cfg['model_name']}_{cfg['layer']}_{cfg['dict_size']}_{cfg['site']}" 74 | post_init_cfg(cfg) 75 | pprint.pprint(cfg) 76 | # %% 77 | 78 | SEED = cfg["seed"] 79 | GENERATOR = torch.manual_seed(SEED) 80 | DTYPES = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16} 81 | np.random.seed(SEED) 82 | random.seed(SEED) 83 | torch.set_grad_enabled(True) 84 | 85 | model = HookedTransformer.from_pretrained(cfg["model_name"]).to(DTYPES[cfg["enc_dtype"]]).to(cfg["device"]) 86 | 87 | n_layers = model.cfg.n_layers 88 | d_model = model.cfg.d_model 89 | n_heads = model.cfg.n_heads 90 | d_head = model.cfg.d_head 91 | d_mlp = model.cfg.d_mlp 92 | d_vocab = model.cfg.d_vocab 93 | # %% 94 | @torch.no_grad() 95 | def get_acts(tokens, batch_size=1024): 96 | _, cache = model.run_with_cache(tokens, stop_at_layer=cfg["layer"]+1, names_filter=cfg["act_name"]) 97 | acts = cache[cfg["act_name"]] 98 | acts = acts.reshape(-1, acts.shape[-1]) 99 | subsample = torch.randperm(acts.shape[0], generator=GENERATOR)[:batch_size] 100 | subsampled_acts = acts[subsample, :] 101 | return subsampled_acts, acts 102 | # sub, acts = get_acts(torch.arange(20).reshape(2, 10), batch_size=3) 103 | # sub.shape, acts.shape 104 | # %% 105 | SAVE_DIR = Path("/workspace/1L-Sparse-Autoencoder/checkpoints") 106 | class AutoEncoder(nn.Module): 107 | def __init__(self, cfg): 108 | super().__init__() 109 | d_hidden = cfg["dict_size"] 110 | l1_coeff = cfg["l1_coeff"] 111 | dtype = DTYPES[cfg["enc_dtype"]] 112 | torch.manual_seed(cfg["seed"]) 113 | self.W_enc = nn.Parameter(torch.nn.init.kaiming_uniform_(torch.empty(cfg["act_size"], d_hidden, dtype=dtype))) 114 | self.W_dec = nn.Parameter(torch.nn.init.kaiming_uniform_(torch.empty(d_hidden, cfg["act_size"], dtype=dtype))) 115 | self.b_enc = nn.Parameter(torch.zeros(d_hidden, dtype=dtype)) 116 | self.b_dec = nn.Parameter(torch.zeros(cfg["act_size"], dtype=dtype)) 117 | 118 | self.W_dec.data[:] = self.W_dec / self.W_dec.norm(dim=-1, keepdim=True) 119 | 120 | self.d_hidden = d_hidden 121 | self.l1_coeff =
l1_coeff 122 | 123 | self.to(cfg["device"]) 124 | 125 | def forward(self, x): 126 | x_cent = x - self.b_dec 127 | acts = F.relu(x_cent @ self.W_enc + self.b_enc) 128 | x_reconstruct = acts @ self.W_dec + self.b_dec 129 | l2_loss = (x_reconstruct.float() - x.float()).pow(2).sum(-1).mean(0) 130 | l1_loss = self.l1_coeff * (acts.float().abs().sum()) 131 | loss = l2_loss + l1_loss 132 | return loss, x_reconstruct, acts, l2_loss, l1_loss 133 | 134 | @torch.no_grad() 135 | def make_decoder_weights_and_grad_unit_norm(self): 136 | W_dec_normed = self.W_dec / self.W_dec.norm(dim=-1, keepdim=True) 137 | W_dec_grad_proj = (self.W_dec.grad * W_dec_normed).sum(-1, keepdim=True) * W_dec_normed 138 | self.W_dec.grad -= W_dec_grad_proj 139 | # Bugfix(?) to ensure W_dec retains unit norm; this was not there when I trained my original autoencoders. 140 | self.W_dec.data = W_dec_normed 141 | 142 | def get_version(self): 143 | version_list = [int(file.name.split(".")[0]) for file in list(SAVE_DIR.iterdir()) if "pt" in str(file)] 144 | if len(version_list): 145 | return 1+max(version_list) 146 | else: 147 | return 0 148 | 149 | def save(self): 150 | version = self.get_version() 151 | torch.save(self.state_dict(), SAVE_DIR/(str(version)+".pt")) 152 | with open(SAVE_DIR/(str(version)+"_cfg.json"), "w") as f: 153 | json.dump(cfg, f) 154 | print("Saved as version", version) 155 | 156 | @classmethod 157 | def load(cls, version): 158 | cfg = (json.load(open(SAVE_DIR/(str(version)+"_cfg.json"), "r"))) 159 | pprint.pprint(cfg) 160 | self = cls(cfg=cfg) 161 | self.load_state_dict(torch.load(SAVE_DIR/(str(version)+".pt"))) 162 | return self 163 | 164 | @classmethod 165 | def load_from_hf(cls, version): 166 | """ 167 | Loads the saved autoencoder from HuggingFace. 168 | 169 | Version is expected to be an int, or "run1" or "run2" 170 | 171 | version 25 is the final checkpoint of the first autoencoder run, 172 | version 47 is the final checkpoint of the second autoencoder run. 
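Example: encoder = AutoEncoder.load_from_hf("run1")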
173 | """ 174 | if version=="run1": 175 | version = 25 176 | elif version=="run2": 177 | version = 47 178 | 179 | cfg = utils.download_file_from_hf("NeelNanda/sparse_autoencoder", f"{version}_cfg.json") 180 | pprint.pprint(cfg) 181 | self = cls(cfg=cfg) 182 | self.load_state_dict(utils.download_file_from_hf("NeelNanda/sparse_autoencoder", f"{version}.pt", force_is_torch=True)) 183 | return self 184 | 185 | # %% 186 | 187 | 188 | 189 | # %% 190 | def shuffle_data(all_tokens): 191 | print("Shuffled data") 192 | return all_tokens[torch.randperm(all_tokens.shape[0])] 193 | 194 | loading_data_first_time = False 195 | if loading_data_first_time: 196 | data = load_dataset("NeelNanda/c4-code-tokenized-2b", split="train", cache_dir="/workspace/cache/") 197 | data.save_to_disk("/workspace/data/c4_code_tokenized_2b.hf") 198 | data.set_format(type="torch", columns=["tokens"]) 199 | all_tokens = data["tokens"] 200 | all_tokens.shape 201 | 202 | 203 | all_tokens_reshaped = einops.rearrange(all_tokens, "batch (x seq_len) -> (batch x) seq_len", x=8, seq_len=128) 204 | all_tokens_reshaped[:, 0] = model.tokenizer.bos_token_id 205 | all_tokens_reshaped = all_tokens_reshaped[torch.randperm(all_tokens_reshaped.shape[0])] 206 | torch.save(all_tokens_reshaped, "/workspace/data/c4_code_2b_tokens_reshaped.pt") 207 | else: 208 | # data = datasets.load_from_disk("/workspace/data/c4_code_tokenized_2b.hf") 209 | all_tokens = torch.load("/workspace/data/c4_code_2b_tokens_reshaped.pt") 210 | all_tokens = shuffle_data(all_tokens) 211 | 212 | # %% 213 | class Buffer(): 214 | """ 215 | This defines a data buffer, to store a bunch of MLP acts that can be used to train the autoencoder. It'll automatically run the model to generate more when it gets halfway empty. 216 | """ 217 | def __init__(self, cfg): 218 | self.buffer = torch.zeros((cfg["buffer_size"], cfg["act_size"]), dtype=torch.bfloat16, requires_grad=False).to(cfg["device"]) 219 | self.cfg = cfg 220 | self.token_pointer = 0 221 | self.first = True 222 | self.refresh() 223 | 224 | @torch.no_grad() 225 | def refresh(self): 226 | self.pointer = 0 227 | with torch.autocast("cuda", torch.bfloat16): 228 | if self.first: 229 | num_batches = self.cfg["buffer_batches"] 230 | else: 231 | num_batches = self.cfg["buffer_batches"]//2 232 | self.first = False 233 | for _ in range(0, num_batches, self.cfg["model_batch_size"]): 234 | tokens = all_tokens[self.token_pointer:self.token_pointer+self.cfg["model_batch_size"]] 235 | _, cache = model.run_with_cache(tokens, stop_at_layer=cfg["layer"]+1, names_filter=cfg["act_name"]) 236 | acts = cache[cfg["act_name"]].reshape(-1, self.cfg["act_size"]) 237 | 238 | # print(tokens.shape, acts.shape, self.pointer, self.token_pointer) 239 | self.buffer[self.pointer: self.pointer+acts.shape[0]] = acts 240 | self.pointer += acts.shape[0] 241 | self.token_pointer += self.cfg["model_batch_size"] 242 | # if self.token_pointer > all_tokens.shape[0] - self.cfg["model_batch_size"]: 243 | # self.token_pointer = 0 244 | 245 | self.pointer = 0 246 | self.buffer = self.buffer[torch.randperm(self.buffer.shape[0]).to(cfg["device"])] 247 | 248 | @torch.no_grad() 249 | def next(self): 250 | out = self.buffer[self.pointer:self.pointer+self.cfg["batch_size"]] 251 | self.pointer += self.cfg["batch_size"] 252 | if self.pointer > self.buffer.shape[0]//2 - self.cfg["batch_size"]: 253 | # print("Refreshing the buffer!") 254 | self.refresh() 255 | return out 256 | 257 | # buffer.refresh() 258 | # %% 259 | 260 | # %% 261 | def replacement_hook(mlp_post, hook, encoder): 
262 | mlp_post_reconstr = encoder(mlp_post)[1] 263 | return mlp_post_reconstr 264 | 265 | def mean_ablate_hook(mlp_post, hook): 266 | mlp_post[:] = mlp_post.mean([0, 1]) 267 | return mlp_post 268 | 269 | def zero_ablate_hook(mlp_post, hook): 270 | mlp_post[:] = 0. 271 | return mlp_post 272 | 273 | @torch.no_grad() 274 | def get_recons_loss(num_batches=5, local_encoder=None): 275 | if local_encoder is None: 276 | local_encoder = encoder 277 | loss_list = [] 278 | for i in range(num_batches): 279 | tokens = all_tokens[torch.randperm(len(all_tokens))[:cfg["model_batch_size"]]] 280 | loss = model(tokens, return_type="loss") 281 | recons_loss = model.run_with_hooks(tokens, return_type="loss", fwd_hooks=[(cfg["act_name"], partial(replacement_hook, encoder=local_encoder))]) 282 | # mean_abl_loss = model.run_with_hooks(tokens, return_type="loss", fwd_hooks=[(cfg["act_name"], mean_ablate_hook)]) 283 | zero_abl_loss = model.run_with_hooks(tokens, return_type="loss", fwd_hooks=[(cfg["act_name"], zero_ablate_hook)]) 284 | loss_list.append((loss, recons_loss, zero_abl_loss)) 285 | losses = torch.tensor(loss_list) 286 | loss, recons_loss, zero_abl_loss = losses.mean(0).tolist() 287 | 288 | print(loss, recons_loss, zero_abl_loss) 289 | score = ((zero_abl_loss - recons_loss)/(zero_abl_loss - loss)) 290 | print(f"{score:.2%}") 291 | # print(f"{((zero_abl_loss - mean_abl_loss)/(zero_abl_loss - loss)).item():.2%}") 292 | return score, loss, recons_loss, zero_abl_loss 293 | # print(get_recons_loss()) 294 | 295 | # %% 296 | # Frequency 297 | @torch.no_grad() 298 | def get_freqs(num_batches=25, local_encoder=None): 299 | if local_encoder is None: 300 | local_encoder = encoder 301 | act_freq_scores = torch.zeros(local_encoder.d_hidden, dtype=torch.float32).to(cfg["device"]) 302 | total = 0 303 | for i in tqdm.trange(num_batches): 304 | tokens = all_tokens[torch.randperm(len(all_tokens))[:cfg["model_batch_size"]]] 305 | 306 | _, cache = model.run_with_cache(tokens, stop_at_layer=cfg["layer"]+1, names_filter=cfg["act_name"]) 307 | acts = cache[cfg["act_name"]] 308 | acts = acts.reshape(-1, cfg["act_size"]) 309 | 310 | hidden = local_encoder(acts)[2] 311 | 312 | act_freq_scores += (hidden > 0).sum(0) 313 | total+=hidden.shape[0] 314 | act_freq_scores /= total 315 | num_dead = (act_freq_scores==0).float().mean() 316 | print("Num dead", num_dead) 317 | return act_freq_scores 318 | # %% 319 | @torch.no_grad() 320 | def re_init(indices, encoder): 321 | new_W_enc = (torch.nn.init.kaiming_uniform_(torch.zeros_like(encoder.W_enc))) 322 | new_W_dec = (torch.nn.init.kaiming_uniform_(torch.zeros_like(encoder.W_dec))) 323 | new_b_enc = (torch.zeros_like(encoder.b_enc)) 324 | print(new_W_dec.shape, new_W_enc.shape, new_b_enc.shape) 325 | encoder.W_enc.data[:, indices] = new_W_enc[:, indices] 326 | encoder.W_dec.data[indices, :] = new_W_dec[indices, :] 327 | encoder.b_enc.data[indices] = new_b_enc[indices] --------------------------------------------------------------------------------