├── .gitattributes
├── .gitignore
├── README.md
├── assets
│   ├── french.png
│   ├── german.png
│   └── scandinavian.png
├── autoencoder
│   ├── autoencoder.py
│   ├── configurator.py
│   ├── feature-browser
│   │   ├── build_website.py
│   │   ├── main_page.py
│   │   └── subpages.py
│   ├── prepare.py
│   ├── resource_loader.py
│   ├── train.py
│   └── utils
│       ├── __init__.py
│       └── plotting_utils.py
├── reproduction.md
├── requirements.txt
└── transformer
    ├── README.md
    ├── config
    │   ├── train_gpt2.py
    │   └── train_shakespeare_char.py
    ├── configurator.py
    ├── data
    │   ├── openwebtext
    │   │   ├── prepare.py
    │   │   └── readme.md
    │   ├── shakespeare
    │   │   ├── prepare.py
    │   │   └── readme.md
    │   └── shakespeare_char
    │       ├── prepare.py
    │       └── readme.md
    ├── hooked_model.py
    ├── model.py
    └── train.py
/.gitattributes:
--------------------------------------------------------------------------------
1 | # Override Jupyter in GitHub language stats for a more accurate estimate of repo code languages
2 | # reference: https://github.com/github/linguist/blob/master/docs/overrides.md#generated-code
3 | *.ipynb linguist-generated
4 | *.html linguist-generated
5 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | .DS_Store
2 | .idea
3 | .ipynb_checkpoints/
4 | .vscode
5 | __pycache__/
6 | *.bin
7 | *.pkl
8 | *.pt
9 | *.pyc
10 | *.sh
11 | *.html
12 | *.arrow
13 | *.ipynb
14 | *.css
15 | *.err
16 | *.out
17 | *.png
18 | input.txt
19 | notes.md
20 | autoencoder.ipynb
21 | blog.md
22 | env/
23 | slurm/
24 | wandb/
25 | notes+papers
26 | notes_exp.md
27 | autoencoder/out/env_3.12/
28 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 | # Towards Monosemanticity: Decomposing Language Models With Dictionary Learning
3 |
4 | This repository reproduces the results of [Anthropic's Sparse Dictionary Learning paper](https://transformer-circuits.pub/2023/monosemantic-features/). The codebase is quite rough, but the results are excellent. See the [feature interface](https://shehper.github.io/feature-interface/) to browse through the features learned by the sparse autoencoder. There are improvements to be made (see the [TODOs](#todos) section below), and I will work on them intermittently as I juggle things in life :)
5 |
6 | I trained a 1-layer transformer model from scratch using [nanoGPT](https://github.com/karpathy/nanoGPT) with $d_{\text{model}} = 128$. Then, I trained a sparse autoencoder with $4096$ features on its MLP activations, as in [Anthropic's paper](https://transformer-circuits.pub/2023/monosemantic-features/). 93% of the autoencoder neurons were alive, only 5% of which were of ultra-low density. There are several interesting features. For example, there is [a feature for the French language](https://shehper.github.io/feature-interface/?page=2011).
7 |
8 |
"""]
174 | logits_text.append("""
""")
175 | logits_text.append("""
Negative Logits
""")
176 | for i in range(10):
177 | token = decode([bottom_logits.indices[feature_id, i].tolist()])
178 | token_html = token.replace('\n', '
⏎')
179 | logits_line = f"""
180 |
181 | {token_html}
182 |
183 | {bottom_logits.values[feature_id, i]:.4f}
184 |
"""
185 | logits_text.append(logits_line)
186 | logits_text.append("""
""")
187 |
188 | logits_text.append("""
""")
189 | logits_text.append("""
Positive Logits
""")
190 | for i in range(10):
191 | token = decode([top_logits.indices[feature_id, i].tolist()])
192 | token_html = token.replace('\n', '
⏎')
193 | logits_line = f"""
194 |
195 | {token_html}
196 |
197 | {top_logits.values[feature_id, i]:.4f}
198 |
"""
199 | logits_text.append(logits_line)
200 | logits_text.append("""
""")
201 | logits_text.append("""
""")
202 | return "".join(logits_text)
203 |
204 | # TODO: merge write_alive_feature_page and write_ultralow_density_feature_page into a single function
205 |
206 | def write_alive_feature_page(feature_id, decode, top_logits, bottom_logits, top_acts_data, sampled_acts_data, dirpath=None):
207 |
208 |     print(f'writing feature page for feature # {feature_id}')
209 | 
210 |     assert isinstance(top_acts_data, TensorDict), "expect top activation data to be presented in a TensorDict"
211 |     assert top_acts_data.ndim == 2, "expect top activation data TensorDict to be 2-dimensional, shape: (k, W)"
212 | 
213 |     assert isinstance(sampled_acts_data, TensorDict), "expect sampled activation data to be presented in a TensorDict"
214 |     assert sampled_acts_data.ndim == 3, "expect sampled activation data TensorDict to be 3-dimensional, shape: (I, X, W)"
215 | 
216 |     assert 'tokens' in top_acts_data.keys() and 'feature_acts' in top_acts_data.keys() and \
217 |         'tokens' in sampled_acts_data.keys() and 'feature_acts' in sampled_acts_data.keys(), \
218 |         "expect input TensorDicts to have 'tokens' and 'feature_acts' keys"
219 | 
220 |     html_content = []
221 | 
222 |     # add page header to the HTML page
223 |     html_content.append(write_feature_page_header())
224 |     html_content.append("""<body>""")  # NOTE: stripped HTML in this function is reconstructed; tags are placeholders
225 | 
226 |
227 |     # add histogram of feature activations, top and bottom logits, and logits histogram
228 |     html_content.append(include_feature_density_histogram(feature_id, dirpath=dirpath))
229 |     html_content.append(include_top_and_bottom_logits(top_logits, bottom_logits, decode, feature_id))
230 |     html_content.append(include_logits_histogram(feature_id, dirpath=dirpath))
231 | 
232 |     # add the feature number heading
233 |     html_content.append(f"""
234 |         <h2>Neuron # {feature_id}</h2>""")
235 | 
236 |     # include a section on top activations
237 |     html_content.append("""
238 |         <h3>Top Activations</h3>
239 |     """)
240 |     html_content.append(write_activations_section(decode, top_acts_data))
241 | 
242 |     # include a section on sampled activations
243 |     I = sampled_acts_data.shape[0]  # number of subsample intervals
244 |     for i in range(I):
245 |         if i < I - 1:
246 |             html_content.append(f"<h3>Subsample Interval {i}</h3>")
247 |         else:
248 |             html_content.append("<h3>Bottom Activations</h3>")
249 |         html_content.append(write_activations_section(decode, sampled_acts_data[i]))
250 | 
251 |     # include the end of the HTML page
252 |     html_content.append("""</body></html>""")
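To make the expected inputs concrete, here is a minimal sketch of what `write_alive_feature_page` consumes, with shapes taken from the asserts above (`k` top examples over a window of `W` tokens; `I` intervals of `X` examples each). The sizes and random data are illustrative only, and `decode` stands for any token-ids-to-string function (e.g. a tokenizer's decode).

```python
import torch
from tensordict import TensorDict

k, W = 10, 64    # top-k examples, context-window width (illustrative)
I, X = 12, 5     # subsample intervals, examples per interval (illustrative)
n_features, vocab_size = 4096, 50304

# top_logits / bottom_logits behave like torch.topk results: .values / .indices
# of shape (n_features, 10), matching the `for i in range(10)` loops above
logits = torch.randn(n_features, vocab_size)
top_logits = torch.topk(logits, k=10, dim=-1)
bottom_logits = torch.topk(logits, k=10, dim=-1, largest=False)

# activation data arrives as TensorDicts keyed by 'tokens' and 'feature_acts'
top_acts_data = TensorDict(
    {"tokens": torch.randint(0, vocab_size, (k, W)),
     "feature_acts": torch.rand(k, W)},
    batch_size=[k, W],        # ndim == 2, shape (k, W)
)
sampled_acts_data = TensorDict(
    {"tokens": torch.randint(0, vocab_size, (I, X, W)),
     "feature_acts": torch.rand(I, X, W)},
    batch_size=[I, X, W],     # ndim == 3, shape (I, X, W)
)
```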