├── LICENSE ├── README.md ├── activations ├── activation_all.py ├── activation_metrics.py ├── activation_probing_dataset.py ├── activation_subset.py ├── common.py ├── quantize.py └── test.py ├── analysis ├── __init__.py ├── load_results.py └── plots │ ├── __init__.py │ ├── common.py │ ├── context_neurons.py │ ├── ngram_stimuli.py │ ├── sparsity_probes.py │ └── weight_stats.py ├── cache_downloads.py ├── config.py ├── experiments ├── __init__.py ├── activations.py ├── common_configs.py ├── inner_loops.py ├── layer_cached_model.py ├── metrics.py ├── probes.py ├── regression_inner_loops.py └── regression_probes.py ├── feature_selection_experiment.py ├── get_activations.py ├── interpretable_neurons ├── pythia-1.4b │ ├── all_neurons_of_interest.csv │ ├── biasxnorm_p20_50_80.csv │ └── top_mono_neurons.csv ├── pythia-160m │ ├── all_neurons_of_interest.csv │ ├── biasxnorm_p20_50_80.csv │ ├── top_compound_words.csv │ └── top_mono_neurons.csv ├── pythia-1b │ ├── all_neurons_of_interest.csv │ ├── biasxnorm_p20_50_80.csv │ ├── compound_superposition.csv │ ├── monosemantic_code_neurons.csv │ ├── top_mono_neurons.csv │ └── wikidata.csv ├── pythia-2.8b │ ├── all_neurons_of_interest.csv │ ├── biasxnorm_p20_50_80.csv │ ├── top_compound_words.csv │ ├── top_mono_neurons.csv │ └── wikidata.csv ├── pythia-410m │ ├── all_neurons_of_interest.csv │ ├── biasxnorm_p20_50_80.csv │ └── top_mono_neurons.csv ├── pythia-6.9b │ ├── all_neurons_of_interest.csv │ ├── biasxnorm_p20_50_80.csv │ ├── monosemantic_distribution_neurons.csv │ ├── top_fact_neurons.csv │ ├── top_mono_neurons.csv │ └── wikidata.csv └── pythia-70m │ ├── all_neurons_of_interest.csv │ ├── biasxnorm_p20_50_80.csv │ ├── monosemantic_language_neurons.csv │ ├── pythia70m_prime_factor_neurons.csv │ ├── top_mono_neurons.csv │ └── wikidata.csv ├── load.py ├── make_feature_datasets.py ├── notebooks ├── Superposition_intuititions.ipynb ├── activation_dev.ipynb ├── basis_alignment.ipynb ├── code_lang_id_results.ipynb ├── counterfact_probing_results.ipynb ├── data_tables.ipynb ├── distribution_id_results.ipynb ├── eos_plot.ipynb ├── ewt_analysis.ipynb ├── factual_neurons.ipynb ├── find_neuron_stimulu_2.ipynb ├── find_neuron_stimulus.ipynb ├── make_codelang_feature_datasets.ipynb ├── make_naturallang_feature_datasets.ipynb ├── mats_results.ipynb ├── monosemantic_investigation.ipynb ├── monosemantic_verification_lang.ipynb ├── natural_language_id_results.ipynb ├── pile_data.ipynb ├── prefix_results.ipynb ├── results_analysis.ipynb ├── suffix_results.ipynb ├── text_features_results.ipynb └── wikidata │ ├── wikidata_ablation_analysis.ipynb │ ├── wikidata_ablation_datasets.ipynb │ ├── wikidata_neurons.ipynb │ └── wikidata_results.ipynb ├── probing_datasets ├── __init__.py ├── common.py ├── counterfact.py ├── distribution_id.py ├── ewt.py ├── language_id.py ├── latex.py ├── multitoken_supervised.py ├── neuron_stimulus.py ├── ngrams.py ├── pile_test.py ├── position.py ├── spacy_supervised.py ├── token_supervised.py └── wikidata.py ├── probing_experiment.py ├── requirements.txt ├── run_ablation.py ├── save_weight_statistics.py ├── scripts ├── activations │ ├── layer_1_20_50_80_percentile.sh │ ├── layer_1_neuron_0_act.sh │ ├── run_activation_all.sh │ ├── run_activation_metrics.sh │ ├── save_all_ewt.sh │ ├── save_all_neurons_of_interest.sh │ ├── save_compound_words_subset.sh │ ├── save_context_subset.sh │ └── save_fact_neurons.sh ├── experiments │ ├── enumerate_monosemantic_all.sh │ ├── feature_selection.sh │ ├── osp_full.sh │ ├── sparsity_sweep_all.sh │ └── 
superposition_compound_words.sh ├── lr_hparam_tuning.sh ├── make_all_feature_datasets.sh ├── osp_hparam_tuning.sh ├── probe_all_context_features.sh ├── probing_dataset_activations │ ├── make_compound_word_features_apd.sh │ ├── make_context_features_apd.sh │ ├── make_ewt_feature_apd.sh │ ├── make_latex_feature_apd.sh │ ├── make_position_feature_apd.sh │ ├── make_text_features_apd.sh │ └── make_wikidata_apd.sh ├── run.sh ├── run_code_lang_id_experiment.sh ├── run_compound_words.sh ├── run_distribution_id.sh ├── run_ewt_fast_probe.sh ├── run_feature_selection.sh ├── run_iterative_pruning.sh ├── run_position_probe.sh ├── run_probe_refactor_test.sh ├── run_sequence_ablation.sh ├── run_superposition_experiment.sh ├── save_weight_statistics.sh └── wikidata │ ├── activations.sh │ ├── dataset.sh │ ├── neurons.sh │ ├── probe.sh │ └── table.sh └── utils.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Wes Gurnee 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # sparse-probing 2 | Code repository for [Finding Neurons in a Haystack: Case Studies with Sparse Probing](https://arxiv.org/abs/2305.01610) 3 | 4 | Pardon our mess. The basic core of sparse probing can be implemented very easily with just sklearn applied to a dataset of activations acquired with raw PyTorch hooks or TransformerLens (a minimal sketch is included below). This repository is almost all experimental infrastructure and analysis specific to our setup of datasets and compute (slurm). 5 | 6 | See [this](https://github.com/wesg52/llm-context-neurons/tree/main) repository for a minimal replication of finding context neurons. 7 | 8 | ## Organization 9 | We expect most people to simply be interested in a large list of relevant neurons, available as CSVs within `interpretable_neurons/`. Note that these are for the Pythia V0 models, which have since been updated on HuggingFace. 10 | 11 | Our top-level scripts for saving activations and running probing experiments can be found in `get_activations.py` and `probing_experiment.py`. All of the command line argument configurations can be viewed in the `scripts/` directory, which contains all of the slurm scripts we used to run our experiments. 
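For orientation, the core loop of a sparse probing experiment looks roughly like the following (a minimal illustrative sketch, not the exact configuration used by `probing_experiment.py`; it assumes you already have an activation matrix `X` of shape `(n_samples, d_mlp)` for one layer and a binary feature label vector `y`):
```
import numpy as np
from sklearn.feature_selection import f_classif
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import average_precision_score
from sklearn.model_selection import train_test_split


def sparse_probe(X, y, k=16, seed=1):
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.3, random_state=seed)
    # Heuristic feature selection: keep the k neurons whose activations
    # best separate the two classes (largest F-statistic).
    f_stat, _ = f_classif(X_train, y_train)
    support = np.argsort(f_stat)[-k:]
    # Fit a dense probe on just those k neurons and evaluate on held-out data.
    probe = LogisticRegression(class_weight='balanced', max_iter=1000)
    probe.fit(X_train[:, support], y_train)
    scores = probe.decision_function(X_test[:, support])
    return support, average_precision_score(y_test, scores)
```
The actual experiments in this repository are variations on this loop: different feature-selection heuristics, optimal subset selection with Gurobi, regression targets, and different ways of aggregating activations over tokens.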
12 | 13 | `probing_datasets/` contains the modules required to make and prepare all of our feature datasets. We recommend simply downloading them from [Dropbox](https://www.dropbox.com/sh/cr2vw4owv6dkw7t/AADvkDwJYKyYDC56q1S8dAS_a?dl=0). 14 | 15 | Analysis and plotting code is distributed within individual notebooks and `analysis/`. 16 | 17 | 18 | ## Instructions for reproducing 19 | Note that our full experiments generate well over 1 TB of data and require substantial GPU and CPU time. 20 | 21 | ### Getting started 22 | Create a virtual environment and install the required packages: 23 | ``` 24 | git clone https://github.com/wesg52/sparse-probing-paper.git 25 | cd sparse-probing-paper 26 | pip install virtualenv 27 | python -m venv sparprob 28 | source sparprob/bin/activate 29 | pip install -r requirements.txt 30 | ``` 31 | 32 | Acquire a Gurobi [license](https://www.gurobi.com/features/academic-named-user-license/) (free for academics). Make sure you are on campus wifi (you may also need to separately install [grbgetkey](https://support.gurobi.com/hc/en-us/articles/360059842732)). 33 | 34 | ### Environment variables 35 | To enable running our code in many different environments, we use environment variables to specify the paths for all data input and output. For example: 36 | ``` 37 | export RESULTS_DIR=/Users/wesgurnee/Documents/mechint/sparse_probing/sparse-probing/results 38 | export FEATURE_DATASET_DIR=/Users/wesgurnee/Documents/mechint/sparse_probing/sparse-probing/feature_datasets 39 | export TRANSFORMERS_CACHE=/Users/wesgurnee/Documents/mechint/sparse_probing/sparse-probing/downloads 40 | export HF_DATASETS_CACHE=/Users/wesgurnee/Documents/mechint/sparse_probing/sparse-probing/downloads 41 | export HF_HOME=/Users/wesgurnee/Documents/mechint/sparse_probing/sparse-probing/downloads 42 | ``` 43 | 44 | 45 | ## Cite us 46 | If you found our work helpful, please cite our paper: 47 | ``` 48 | @article{gurnee2023finding, 49 | title={Finding Neurons in a Haystack: Case Studies with Sparse Probing}, 50 | author={Gurnee, Wes and Nanda, Neel and Pauly, Matthew and Harvey, Katherine and Troitskii, Dmitrii and Bertsimas, Dimitris}, 51 | journal={arXiv preprint arXiv:2305.01610}, 52 | year={2023} 53 | } 54 | ``` 55 | -------------------------------------------------------------------------------- /activations/activation_all.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | from datasets.arrow_dataset import Dataset 5 | from einops import rearrange 6 | from torch.utils.data import DataLoader 7 | from tqdm import tqdm 8 | from transformer_lens import HookedTransformer 9 | 10 | from activations.common import (get_hook_name, load_json, load_tensor, 11 | save_json, save_tensor) 12 | 13 | 14 | @torch.no_grad() 15 | def get_full_activation_tensor( 16 | model: HookedTransformer, 17 | dataset: Dataset, 18 | device=None, 19 | batch_size=16, 20 | save_postactivation=False, 21 | verbose=False, 22 | layers=None, 23 | positions=None, 24 | flatten_and_ignore_padding=False 25 | ): 26 | ''' 27 | Collect activations for all examples in a dataset 28 | 29 | Returns a (n_layer, n_neuron, n_sequence, seq_len) tensor of activations 30 | ''' 31 | if device is None: 32 | device = "cuda" if torch.cuda.is_available() else "cpu" 33 | if layers is None: 34 | layers = list(range(model.cfg.n_layers)) 35 | if positions is None: 36 | positions = list(range(len(dataset[0]['tokens']))) 37 | 38 | # make hooks 39 | layer_names = [get_hook_name(lix, save_postactivation) for lix 
in layers] 40 | layer_ix = {name: lix for lix, name in enumerate(layer_names)} 41 | 42 | # layer x sequence_dim x position x neuron 43 | n_seqs, ctx_len = dataset['tokens'].shape 44 | activation_shape = (len(layers), n_seqs, len(positions), model.cfg.d_mlp) 45 | activations = torch.zeros(activation_shape, dtype=torch.float16) 46 | 47 | batch_num = 0 # nonlocal variable to include in hooks 48 | 49 | def save_hook(tensor, hook): 50 | nonlocal batch_num 51 | nonlocal batch_size 52 | offset = batch_num * batch_size 53 | layer = layer_ix[hook.name] 54 | batch_act = tensor.detach().cpu()[:, positions, :].to(torch.float16) 55 | activations[layer, offset:offset + batch_size, :, :] = batch_act 56 | 57 | for name in layer_names: 58 | model.add_hook(name, save_hook) 59 | 60 | # iterate over dataset 61 | # TODO: can this loop be put into its own function and reused across the activations experiments? 62 | dataloader = DataLoader( 63 | dataset['tokens'], batch_size=batch_size, shuffle=False) 64 | for batch in tqdm(dataloader, disable=not verbose): 65 | model.forward(batch.to(device), return_type=None) 66 | batch_num += 1 67 | model.reset_hooks() 68 | 69 | # to layer, neuron, sequence, position 70 | activations = rearrange(activations, 'l s p n -> l n s p') 71 | if flatten_and_ignore_padding: 72 | # layer, neuron, (sequence, position) 73 | activations = activations[:, :, dataset['tokens'] > 1] 74 | 75 | return activations 76 | 77 | 78 | def save_full_activation_tensor(experiment_dir, activations, metadata, output_precision=16): 79 | ''' 80 | Save activations and metadata to disk 81 | ''' 82 | os.makedirs(experiment_dir, exist_ok=True) 83 | # TODO(wesg): add 8bit option 84 | dtype = torch.float32 if output_precision == 32 else torch.float16 85 | save_tensor(os.path.join(experiment_dir, 'activations.pt'), 86 | activations.to(dtype)) 87 | save_json(os.path.join(experiment_dir, 'metadata.json'), metadata) 88 | 89 | 90 | def load_activation_all(experiment_dir): 91 | ''' 92 | Load activations and metadata from disk 93 | ''' 94 | activations = load_tensor(os.path.join(experiment_dir, 'activations.pt')) 95 | metadata = load_json(os.path.join(experiment_dir, 'metadata.json')) 96 | return activations, metadata 97 | -------------------------------------------------------------------------------- /activations/activation_metrics.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | from datasets.arrow_dataset import Dataset 5 | from einops import rearrange, repeat 6 | from torch.utils.data import DataLoader 7 | from tqdm import tqdm 8 | from functools import partial 9 | from functorch import vmap 10 | from transformer_lens import HookedTransformer 11 | 12 | from activations.common import ( 13 | get_hook_name, load_json, load_tensor, 14 | save_json, save_tensor) 15 | 16 | 17 | @torch.no_grad() 18 | def get_activations_hist(activations, hist_min, hist_max, n_bin): 19 | ''' 20 | Compute independent histograms for each neuron in the activations tensor. 
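Expects `activations` of shape (n_layer, n_neuron, n_tokens); values outside [hist_min, hist_max] are clamped so that out-of-range activations are counted in the first/last bins.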
21 | ''' 22 | n_layer, n_neuron, _ = activations.shape 23 | 24 | # layer, neuron, activations (= batch_size * seq_len) 25 | all_activations_by_neuron = rearrange( 26 | activations, 'l n a -> (l n) a') 27 | clamped_activations = torch.clamp( 28 | all_activations_by_neuron, hist_min + 1e-6, hist_max - 1e-6 29 | ) 30 | 31 | bin_edges = torch.linspace(hist_min, hist_max, n_bin+1) 32 | binned_histogram = partial(torch.histogram, bins=bin_edges) 33 | vectorized_histogram = vmap(binned_histogram) 34 | hist_by_neuron, _ = vectorized_histogram(clamped_activations) 35 | 36 | # layer, neuron, bin 37 | return rearrange( 38 | hist_by_neuron.to(torch.long), 39 | '(l n) b -> l n b', 40 | l=n_layer, n=n_neuron, b=n_bin 41 | ) 42 | 43 | 44 | @torch.no_grad() 45 | def get_activations_top_k(activations, top_k_values, top_k_indices, step, batch_size, seq_len): 46 | ''' 47 | Compute top k most activating examples for each neuron over both the current 48 | batch and previous batches. Modifies top_k_values and top_k_indices in place. 49 | ''' 50 | # TODO: collect top k for every bin in the histogram 51 | n_layer, n_neuron, top_k = top_k_values.shape 52 | cur_batch_size = activations.shape[2] / seq_len 53 | 54 | all_values = torch.cat([top_k_values, activations], dim=2) 55 | batch_indices = repeat( 56 | int(step*batch_size*seq_len) + 57 | torch.arange(int(seq_len*cur_batch_size)), 58 | 'x -> l n x', l=n_layer, n=n_neuron) 59 | all_indices = torch.concat([top_k_indices, batch_indices], dim=2) 60 | new_indices = torch.empty((n_layer, n_neuron, top_k), dtype=torch.long) 61 | 62 | torch.topk(all_values, top_k, dim=2, sorted=True, 63 | out=(top_k_values, new_indices)) 64 | torch.gather(all_indices, 2, new_indices, out=top_k_indices) 65 | 66 | 67 | @torch.no_grad() 68 | def get_activation_metrics( 69 | model: HookedTransformer, 70 | dataset: Dataset, 71 | device=None, 72 | batch_size=16, 73 | top_k=30, 74 | n_bin=100, 75 | hist_min=-10, 76 | hist_max=10, 77 | save_postactivation=False, 78 | verbose=False, 79 | ): 80 | ''' 81 | For each neuron find the top k most activating examples and activation histogram over the provided dataset. 
82 | 83 | Returns 84 | top_k_seqix: (n_layer, n_neuron, top_k) tensor of sequence index for top activating examples 85 | top_k_pos: (n_layer, n_neuron, top_k) tensor of position for top activating examples 86 | bin_counts: (n_layer, n_neuron, n_bin) tensor of histogram bin counts where the final bin captures values greater than hist_max 87 | ''' 88 | n_layer = model.cfg.n_layers 89 | n_neuron = model.cfg.d_mlp 90 | seq_len = len(dataset[0]['tokens']) 91 | 92 | if device is None: 93 | device = "cuda" if torch.cuda.is_available() else "cpu" 94 | 95 | bin_counts = torch.zeros((n_layer, n_neuron, n_bin), dtype=torch.long) 96 | top_k_values = torch.full( 97 | (n_layer, n_neuron, top_k), float('-inf'), dtype=torch.float32) 98 | top_k_indices = torch.empty((n_layer, n_neuron, top_k), dtype=torch.long) 99 | 100 | # add hooks 101 | mlp_activations = {} 102 | 103 | def save_hook(tensor, hook): 104 | mlp_activations[hook.name] = tensor.detach().cpu() 105 | 106 | layer_names = [get_hook_name(lix, save_postactivation) 107 | for lix in range(n_layer)] 108 | for name in layer_names: 109 | model.add_hook(name, save_hook) 110 | 111 | # iterate over dataset 112 | dataloader = DataLoader( 113 | dataset['tokens'], batch_size=batch_size, shuffle=False) 114 | for step, batch in enumerate(tqdm(dataloader, disable=not verbose)): 115 | model.forward(batch.to(device), return_type=None) 116 | activations = rearrange([mlp_activations[name] for name in layer_names], 117 | 'layer seq token neuron -> layer neuron (seq token)') 118 | 119 | # update running top k and histogram 120 | get_activations_top_k(activations, top_k_values, 121 | top_k_indices, step, batch_size, seq_len) 122 | bin_counts += get_activations_hist(activations, 123 | hist_min, hist_max, n_bin) 124 | 125 | model.reset_hooks() 126 | 127 | top_k_seqix = top_k_indices // seq_len 128 | top_k_pos = top_k_indices % seq_len 129 | return top_k_seqix, top_k_pos, bin_counts 130 | 131 | 132 | def save_activation_metrics(experiment_dir, top_k_seqix, top_k_pos, bin_counts, 133 | metadata, output_precision=16): 134 | ''' 135 | Save the activation metrics results and metadata 136 | ''' 137 | os.makedirs(experiment_dir, exist_ok=True) 138 | # TODO(wesg): precision 139 | save_tensor( 140 | os.path.join(experiment_dir, 'top_k_seqix.pt'), top_k_seqix) 141 | save_tensor( 142 | os.path.join(experiment_dir, 'top_k_pos.pt'), top_k_pos) 143 | save_tensor( 144 | os.path.join(experiment_dir, 'bin_counts.pt'), bin_counts) 145 | 146 | save_json(os.path.join(experiment_dir, 'metadata.json'), metadata) 147 | 148 | 149 | def load_activation_metrics(experiment_dir): 150 | ''' 151 | Load the activation metrics resutls and metadata 152 | ''' 153 | top_k_seqix = load_tensor(os.path.join(experiment_dir, 'top_k_seqix.pt')) 154 | top_k_pos = load_tensor(os.path.join(experiment_dir, 'top_k_pos.pt')) 155 | bin_counts = load_tensor(os.path.join(experiment_dir, 'bin_counts.pt')) 156 | metadata = load_json(os.path.join(experiment_dir, 'metadata.json')) 157 | return top_k_seqix, top_k_pos, bin_counts, metadata -------------------------------------------------------------------------------- /activations/common.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import time 4 | 5 | import torch 6 | 7 | 8 | def get_hook_name(layerix: int, use_post=False): 9 | return f'blocks.{layerix}.mlp.hook_{"post" if use_post else "pre"}' 10 | 11 | 12 | def time_function(function, *args, **kwargs): 13 | start_time = time.perf_counter() 14 | out = 
function(*args, **kwargs) 15 | return out, time.perf_counter() - start_time 16 | 17 | 18 | def save_tensor(filename, tensor): 19 | with open(filename, 'wb') as f: 20 | torch.save(tensor, f) 21 | 22 | 23 | def load_tensor(filename): 24 | with open(filename, 'rb') as f: 25 | return torch.load(f) 26 | 27 | 28 | def save_json(filename, data): 29 | with open(filename, 'w') as f: 30 | json.dump(data, f) 31 | 32 | 33 | def load_json(filename): 34 | with open(filename, 'r') as f: 35 | return json.load(f) 36 | 37 | 38 | def get_experiment_dir(args): 39 | return os.path.join( 40 | os.environ.get('RESULTS_DIR', 'results'), 41 | args.experiment_type[0], 42 | args.model, 43 | args.feature_dataset, 44 | args.experiment_name if args.experiment_name else args.experiment_type 45 | ) 46 | 47 | 48 | def get_experiment_info_str(args): 49 | info_str = "" 50 | info_str += f'Running activations experiment "{args.experiment_name}" of type "{args.experiment_type}"\n' 51 | info_str += f'\tmodel: {args.model}\n' 52 | info_str += f'\tfeature dataset name: {args.feature_dataset}\n' 53 | info_str += f'\tn_threads: {args.n_threads}, device: {args.device}\n' 54 | info_str += f'\tbatch size: {args.batch_size}, output_precision: {args.output_precision}\n' 55 | return info_str 56 | 57 | 58 | def get_experiment_metadata(args, total_time): 59 | arg_dict = vars(args) 60 | arg_dict['total_time'] = total_time 61 | arg_dict['current_time'] = time.time() 62 | return arg_dict 63 | 64 | 65 | def expand_abs(path): 66 | return os.path.abspath(os.path.expanduser(path)) 67 | -------------------------------------------------------------------------------- /activations/quantize.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | 5 | def quantize_8bit(input): 6 | """Quantize a tensor to 8 bits per value (column-wise, along axis 0). 7 | 8 | Args: 9 | input (torch.Tensor): The tensor to quantize. 10 | 11 | 12 | Returns: 13 | tuple: (uint8 quantized tensor, offset, scale) needed to dequantize. 14 | """ 15 | offset = input.min(axis=0).values 16 | scale = (input.max(axis=0).values - offset) / 255 17 | quant = ((input - offset) / scale).float().round().clamp(0, 18 | 255).to(torch.uint8) 19 | return quant, offset, scale 20 | 21 | 22 | def unquantize_8bit(input, offset, scale): 23 | """Dequantize an 8-bit tensor produced by quantize_8bit. 24 | 25 | Args: 26 | input (torch.Tensor): The uint8 tensor to dequantize. 27 | offset, scale (torch.Tensor): The offset and scale returned by quantize_8bit. 28 | 29 | Returns: 30 | torch.Tensor: The dequantized (float16) tensor. 
31 | """ 32 | return input.to(torch.float16) * scale + offset 33 | -------------------------------------------------------------------------------- /activations/test.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | from activations.activation_all import get_full_activation_tensor 5 | from activations.activation_metrics import get_activation_metrics 6 | from activations.activation_subset import get_activation_subset 7 | from activations.common import time_function 8 | 9 | 10 | def test_get_activation_metrics( 11 | model, dataset, batch_size=16, top_k=30, n_bin=30, hist_min=-.2, hist_max=2): 12 | ''' 13 | Validate the outputs of the `get_activation_metrics` function 14 | ''' 15 | n_layer = model.cfg.n_layers 16 | n_neuron = model.cfg.d_mlp 17 | 18 | (top_k_seqix, top_k_pos, bin_counts), act_metrics_time = time_function( 19 | get_activation_metrics, model, dataset, batch_size=batch_size, top_k=top_k, 20 | n_bin=n_bin, hist_min=hist_min, hist_max=hist_max, verbose=True) 21 | activations, act_all_time = time_function( 22 | get_full_activation_tensor, model, dataset, verbose=True) 23 | print( 24 | f'get_activation_stats: {act_metrics_time:.3f}, get_mlp_activations: {act_all_time:.3f}') 25 | 26 | # validate top k output 27 | wrong_act_counter = 0 28 | for lix in range(n_layer): 29 | for nix in range(n_neuron): 30 | neuron_top_activations = activations[lix, nix, :, :].flatten().sort( 31 | descending=True)[0][:top_k] 32 | for k in range(top_k): 33 | seqix = top_k_seqix[lix, nix, k] 34 | pos = top_k_pos[lix, nix, k] 35 | activation = activations[lix, nix, seqix, pos] 36 | wrong_act_counter += activation not in neuron_top_activations 37 | print(f'Top k output correct: {wrong_act_counter==0}') 38 | 39 | # validate histogram output 40 | # NOTE: I'm not sure why the histograms don't match exactly, but they are very close. 41 | # I can't find a pattern to the errors other than it probably has something to do 42 | # with values around the bin edges as differences always come in pairs with one bin 43 | # being off by 1 and an adjacent one being off by -1. 44 | # If this is a problem, we can try to figure out what's going on, but for now it 45 | # seems fine to leave it. 
46 | bin_counts_compare = torch.zeros( 47 | (n_layer, n_neuron, n_bin), dtype=torch.float32) 48 | bin_edges = torch.cat([torch.linspace(hist_min, hist_max, n_bin), 49 | torch.tensor([float('inf')])]) 50 | for lix in range(n_layer): 51 | for nix in range(n_neuron): 52 | bin_counts_compare[lix, nix, :] += torch.histogram( 53 | activations[lix, nix, :, :].flatten(), bin_edges).hist 54 | prop_diff = 1 - ((bin_counts == bin_counts_compare).sum() / 55 | bin_counts.numel()).item() 56 | print( 57 | f'bin_counts exact match: {(bin_counts == bin_counts_compare).all()}, near match: {prop_diff<1e-5}, prop different: {prop_diff:.7f}') 58 | 59 | 60 | def test_get_activation_subset(model, dataset, batch_size=16, subset_size=10): 61 | ''' 62 | Validate the outputs of the `get_activation_subset function 63 | ''' 64 | n_layer = model.cfg.n_layers 65 | n_neuron = model.cfg.d_mlp 66 | neuron_subset = [(np.random.randint(n_layer), np.random.randint(n_neuron)) 67 | for _ in range(subset_size)] 68 | 69 | activation_subset, subset_time = time_function( 70 | get_activation_subset, model, dataset, neuron_subset, 71 | batch_size=batch_size, verbose=True) 72 | activations, all_time = time_function( 73 | get_full_activation_tensor, model, dataset, batch_size=batch_size, verbose=True) 74 | print( 75 | f'get_activation_subset: {subset_time:.3f}, get_mlp_activations: {all_time:.3f}') 76 | 77 | # validate output 78 | keys_match = set(neuron_subset) == set(activation_subset.keys()) 79 | print(f'Subset keys correct: {keys_match}') 80 | if not keys_match: 81 | return 82 | values_match = True 83 | for neuron in neuron_subset: 84 | lix, nix = neuron 85 | neuron_activations = activations[lix, nix, :, :] 86 | if not (neuron_activations == activation_subset[neuron]).all(): 87 | values_match = False 88 | break 89 | print(f'Subset values correct: {values_match}') 90 | -------------------------------------------------------------------------------- /analysis/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wesg52/sparse-probing-paper/a610e102c6e25a6ef9cc16c3a2abb736aa90849b/analysis/__init__.py -------------------------------------------------------------------------------- /analysis/load_results.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import pickle 4 | import pandas as pd 5 | import numpy as np 6 | 7 | 8 | def load_probing_experiment_results(results_dir, experiment_name, model_name, dataset_name, inner_loop, uncollapse_features=False): 9 | result_dir = os.path.join( 10 | results_dir, experiment_name, model_name, dataset_name, inner_loop) 11 | results = {} 12 | config = None 13 | for result_file in os.listdir(result_dir): 14 | if result_file == 'config.json': 15 | config_file = os.path.join(result_dir, result_file) 16 | config = json.load(open(config_file, 'r')) 17 | continue 18 | # example: heuristic_sparsity_sweep.arxiv.pythia-125m.mlp,hook_post.max.0.p 19 | _, feature, _, hook_loc, aggregation, layer, _ = result_file.split('.') 20 | layer = int(layer) 21 | hook_loc = hook_loc.replace(',', '.') 22 | results_dict = pickle.load( 23 | open(os.path.join(result_dir, result_file), 'rb')) 24 | if uncollapse_features: # --save_features_together enabled 25 | for k, v in results_dict.items(): 26 | results[(f'{k}', layer, aggregation, hook_loc)] = v 27 | else: 28 | results[feature, layer, aggregation, hook_loc] = results_dict 29 | return results, config 30 | 31 | 32 | def 
load_probing_experiment_results_old(results_dir, experiment_name, inner_loop, model_name): 33 | # old version 34 | result_dir = os.path.join( 35 | results_dir, experiment_name, inner_loop, model_name) 36 | results = {} 37 | for result_file in os.listdir(result_dir): 38 | if len(result_file.split('.')) == 5: 39 | _, feature, _, layer, file_type = result_file.split('.') 40 | else: 41 | continue 42 | print(result_file) 43 | _, feature, probe_loc, _, layer, file_type = result_file.split( 44 | '.') 45 | layer = int(layer[1:]) 46 | if feature not in results: 47 | results[feature] = {} 48 | results[feature][layer] = pickle.load( 49 | open(os.path.join(result_dir, result_file), 'rb')) 50 | return results 51 | 52 | 53 | def make_heuristic_probing_results_df(results_dict): 54 | flattened_results = {} 55 | for feature in results_dict: 56 | for layer in results_dict[feature]: 57 | for sparsity in results_dict[feature][layer]: 58 | flattened_results[(feature, layer, sparsity) 59 | ] = results_dict[feature][layer][sparsity] 60 | rdf = pd.DataFrame(flattened_results).T.sort_index().rename_axis( 61 | index=['feature', 'layer', 'k']) 62 | return rdf 63 | 64 | 65 | def collect_monosemantic_results(probing_results): 66 | dfs = {} 67 | for k, result in probing_results.items(): 68 | dfs[k] = pd.DataFrame(result).T 69 | rdf = pd.concat(dfs) # .reset_index() 70 | rdf.rename_axis( 71 | index=['feature', 'layer', 'aggregation', 'hook_loc', 'neuron'], 72 | inplace=True 73 | ) 74 | return rdf.sort_index() 75 | -------------------------------------------------------------------------------- /analysis/plots/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wesg52/sparse-probing-paper/a610e102c6e25a6ef9cc16c3a2abb736aa90849b/analysis/plots/__init__.py -------------------------------------------------------------------------------- /analysis/plots/common.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wesg52/sparse-probing-paper/a610e102c6e25a6ef9cc16c3a2abb736aa90849b/analysis/plots/common.py -------------------------------------------------------------------------------- /analysis/plots/ngram_stimuli.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import seaborn as sns 3 | import matplotlib.pyplot as plt 4 | 5 | 6 | def make_stimuli_plot_data(neuron_stimuli, activation_df, decoded_vocab): 7 | stimuli_medians = {} 8 | token_stimulus_ordering = {} 9 | ngram_activation_dfs = {} 10 | 11 | for t in neuron_stimuli.keys(): 12 | t_adf = copy.deepcopy(activation_df.query('token==@t')) 13 | t_adf['class_label'] = 'other' # start with other by default 14 | stimulus_order = [] 15 | for stimulus in neuron_stimuli[t]: 16 | n_gram = max(len(p) for p in stimulus) 17 | for i in range(n_gram, 0, -1): 18 | t_adf[f'prefix-{i}'] = activation_df.loc[ 19 | t_adf.index.values - i, 'token'].values 20 | 21 | # use first prefix in the set as class label 22 | stimulus_string = f"{''.join([decoded_vocab[p] for p in stimulus[0]])}" 23 | stimulus_order.append(stimulus_string) 24 | 25 | for prefix in stimulus: 26 | if len(prefix) == 1: 27 | t_adf.loc[t_adf['prefix-1'] == prefix[0], 28 | 'class_label'] = stimulus_string 29 | elif len(prefix) == 2: 30 | t_adf.loc[ 31 | (t_adf['prefix-2'] == prefix[0]) & 32 | (t_adf['prefix-1'] == prefix[1]), 33 | 'class_label' 34 | ] = stimulus_string 35 | elif len(prefix) == 3: 36 | t_adf.loc[ 37 | 
(t_adf['prefix-3'] == prefix[0]) & 38 | (t_adf['prefix-2'] == prefix[1]) & 39 | (t_adf['prefix-1'] == prefix[2]), 40 | 'class_label' 41 | ] = stimulus_string 42 | else: 43 | raise ValueError( 44 | f"prefix length {len(prefix)} not supported") 45 | 46 | stimulus_order.append('other') 47 | token_stimulus_ordering[t] = stimulus_order 48 | ngram_activation_dfs[t] = t_adf 49 | # used to order the subplots 50 | max_class_median_activation = t_adf.groupby( 51 | 'class_label').activation.median().max() 52 | stimuli_medians[t] = max_class_median_activation 53 | 54 | return ngram_activation_dfs, stimuli_medians, token_stimulus_ordering 55 | 56 | 57 | def make_neuron_stimulus_plot(ngram_activation_dfs, token_ordering, token_stimulus_ordering, decoded_vocab, title=None): 58 | fig, axs = plt.subplots(1, len(token_ordering), figsize=( 59 | 1.3 * len(token_ordering), 4.5), sharey=True) 60 | for ix, t in enumerate(token_ordering): 61 | t_adf = ngram_activation_dfs[t] 62 | token_stimuli_order = token_stimulus_ordering[t] 63 | # see https://stackoverflow.com/questions/46173419/seaborn-change-color-according-to-hue-name 64 | # always want orange to be the last color 65 | palette_dict = { 66 | 2: ["C0", "C1"], 67 | 3: ["C0", "C2", "C1"], 68 | 4: ["C0", "C2", "C3", "C1"], 69 | } 70 | palette = palette_dict[len(token_stimuli_order)] 71 | ax = axs[ix] 72 | sns.boxplot( 73 | t_adf, x='token', y='activation', hue='class_label', 74 | hue_order=token_stimuli_order, palette=palette, ax=ax, 75 | whis=(5, 95), fliersize=1 76 | ) 77 | ax.legend(loc='lower left', prop={'size': 6}, frameon=False) 78 | 79 | # formatting 80 | ax.set_xlabel('') 81 | ax.set_ylabel('pre-activation') 82 | ax.set_xticklabels([f"'{decoded_vocab[t]}'"]) 83 | ax.spines['right'].set_visible(False) 84 | ax.spines['top'].set_visible(False) 85 | ax.grid(axis='y', color='lightgray', linestyle='--', linewidth=0.75) 86 | ax.tick_params(axis='y', which='both', length=0) 87 | if ix > 0: 88 | ax.set_ylabel('') 89 | ax.spines['left'].set_visible(False) 90 | 91 | plt.subplots_adjust(wspace=0, hspace=0) 92 | plt.suptitle(title, y=0.95, fontsize=16) 93 | return ax 94 | 95 | 96 | def make_intro_polysemantic_plot(ngram_activation_dfs, token_ordering, token_stimulus_ordering, decoded_vocab, title=None): 97 | fig, axs = plt.subplots(1, len(token_ordering), figsize=( 98 | 2 * len(token_ordering), 4), sharey=True) 99 | for ix, t in enumerate(token_ordering): 100 | t_adf = ngram_activation_dfs[t] 101 | token_stimuli_order = token_stimulus_ordering[t] 102 | # see https://stackoverflow.com/questions/46173419/seaborn-change-color-according-to-hue-name 103 | # always want orange to be the last color 104 | palette = ["C0", "C2", "C1"] if len( 105 | token_stimuli_order) == 3 else ["C0", "C1"] 106 | ax = axs[ix] 107 | sns.boxplot( 108 | t_adf, x='token', y='activation', hue='class_label', 109 | hue_order=token_stimuli_order, palette=palette, ax=ax, 110 | whis=(5, 95), fliersize=1 111 | ) 112 | ax.legend(loc='lower left', prop={'size': 9}, frameon=False) 113 | 114 | # formatting 115 | ax.set_xlabel('') 116 | ax.set_ylabel('pre-activation') 117 | ax.set_xticklabels([f"'{decoded_vocab[t]}'"]) 118 | ax.spines['right'].set_visible(False) 119 | ax.spines['top'].set_visible(False) 120 | ax.grid(axis='y', color='lightgray', linestyle='--', linewidth=0.75) 121 | ax.tick_params(axis='y', which='both', length=0) 122 | if ix > 0: 123 | ax.set_ylabel('') 124 | ax.spines['left'].set_visible(False) 125 | 126 | plt.subplots_adjust(wspace=0, hspace=0) 127 | plt.ylim(-2.5, 3.9) 128 | 
plt.suptitle(title, y=0.95) 129 | return axs 130 | -------------------------------------------------------------------------------- /analysis/plots/sparsity_probes.py: -------------------------------------------------------------------------------- 1 | import matplotlib as mpl 2 | import math 3 | import matplotlib.pyplot as plt 4 | import numpy as np 5 | 6 | 7 | # fade (linear interpolate) from color c1 (at mix=0) to c2 (mix=1) 8 | def colorFader(c1='red', c2='blue', mix=0): 9 | c1 = np.array(mpl.colors.to_rgb(c1)) 10 | c2 = np.array(mpl.colors.to_rgb(c2)) 11 | return mpl.colors.to_hex((1-mix)*c1 + mix*c2) 12 | 13 | 14 | def plot_metric_over_sparsity_per_layer(rdf, features=(), metric='test_pr_auc'): 15 | if len(features) == 0: 16 | features = sorted(rdf.index.get_level_values(0).unique()) 17 | 18 | layers = sorted(rdf.index.get_level_values(1).unique()) 19 | ks = sorted(rdf.index.get_level_values(2).unique()) 20 | n_rows = math.ceil(len(layers)/4) 21 | fig, axs = plt.subplots(n_rows, 4, figsize=(10, 2.5*n_rows), sharey=True) 22 | for l in layers: 23 | ax = axs[l//4, l % 4] 24 | for f in features: 25 | perf = rdf.loc[f, l, :][metric] 26 | ax.plot(ks, perf, label=f) 27 | ax.scatter(ks, perf, s=2) 28 | 29 | ax.set_xscale('log') 30 | ax.set_title(f'layer {l}') 31 | if l % 4 == 0: 32 | ax.set_ylabel(metric) 33 | if l//4 == n_rows-1: 34 | ax.set_xlabel('sparsity') 35 | if l == 3: 36 | ax.legend() 37 | plt.tight_layout() 38 | 39 | 40 | def plot_layer_metric_over_sparsity_per_feature(rdf, features=(), metric='test_pr_auc', n_cols=3): 41 | if len(features) == 0: 42 | features = sorted(rdf.index.get_level_values(0).unique()) 43 | 44 | layers = sorted(rdf.index.get_level_values(1).unique()) 45 | ks = sorted(rdf.index.get_level_values(2).unique()) 46 | n_rows = math.ceil(len(features) / n_cols) 47 | fig, axs = plt.subplots( 48 | n_rows, n_cols, figsize=(10, 2.5*n_rows+5), sharey=True) 49 | for f_ix, f in enumerate(features): 50 | ax = axs[f_ix//n_cols, f_ix % n_cols] if n_rows > 1 else axs[f_ix % n_cols] 51 | for l in layers: 52 | perf = rdf.loc[f, l, :][metric] 53 | ax.plot(ks, perf, label=l, color=colorFader(mix=l/len(layers))) 54 | ax.scatter(ks, perf, s=2, color=colorFader(mix=l/len(layers))) 55 | 56 | ax.set_xscale('log') 57 | ax.set_title(f'feature {f}') 58 | if f_ix % n_cols == 0: 59 | ax.set_ylabel(metric) 60 | if f_ix//n_cols == 2: 61 | ax.set_xlabel('sparsity') 62 | if f_ix % n_cols == 0 and f_ix//n_cols == n_rows-1: 63 | ax.legend(ncols=4, bbox_to_anchor=(0, -1, 1, 1), loc='lower left') 64 | plt.tight_layout() 65 | -------------------------------------------------------------------------------- /analysis/plots/weight_stats.py: -------------------------------------------------------------------------------- 1 | import matplotlib.gridspec as gridspec 2 | import matplotlib.pyplot as plt 3 | import numpy as np 4 | import pandas as pd 5 | import seaborn as sns 6 | 7 | 8 | def weight_boxplot(model_stats, m, ax, tail_cutoff=0.5): 9 | in_norm = model_stats[m]['in_norm'] 10 | in_bias = model_stats[m]['in_bias'] 11 | biasxnorm = in_norm * in_bias 12 | 13 | inlier_range = np.percentile(biasxnorm, tail_cutoff), np.percentile( 14 | biasxnorm, 100-tail_cutoff) 15 | 16 | sns.boxplot(data=pd.DataFrame(biasxnorm).T, ax=ax, fliersize=1) 17 | ax.set_xlabel(f'layer ({m})') 18 | ax.set_ylabel('$||W_{in}||_2 b_{in}$') 19 | ax.set_ylim(inlier_range) 20 | ax.axhline(0, color='black', linestyle='--', linewidth=1, alpha=0.5) 21 | 22 | 23 | def all_model_weight_boxplot(model_stats, figsize=(12, 15)): 24 | # You can 
adjust the width and height as needed 25 | fig = plt.figure(figsize=figsize) 26 | # Create a GridSpec object with 6 rows and 3 columns 27 | gs = gridspec.GridSpec(6, 3) 28 | 29 | # Create the first row plots 30 | ax1 = plt.subplot(gs[0, 0]) 31 | ax2 = plt.subplot(gs[0, 1:]) 32 | 33 | weight_boxplot(model_stats, 'pythia-70m', ax1) 34 | weight_boxplot(model_stats, 'pythia-160m', ax2) 35 | 36 | plot_positions = [ 37 | (1, 0, 1), (1, 1, 2), 38 | (2, 0, 3), (3, 0, 4), 39 | (4, 0, 5), (5, 0, 6) 40 | ] 41 | 42 | # fig, axs = plt.subplots(len(models), 1, figsize=(12, 3 * len(models))) 43 | for m_ix, m in enumerate(model_stats): 44 | if m_ix <= 1: 45 | continue 46 | row, col, plot_num = plot_positions[m_ix - 1] 47 | ax = plt.subplot(gs[row, :]) 48 | weight_boxplot(model_stats, m, ax) 49 | 50 | plt.suptitle('Distribution of $||W_{in}||_2 b_{in}$ by layer') 51 | plt.tight_layout() 52 | 53 | 54 | def plot_normalized_median_norm_bias(models, model_stats, ax=None): 55 | if ax is None: 56 | fig, ax = plt.subplots(1, 1, figsize=(5, 4)) 57 | for model in models[1:]: 58 | in_norm = model_stats[model]['in_norm'] 59 | in_bias = model_stats[model]['in_bias'] 60 | n_layers, n_neurons = in_norm.shape 61 | 62 | sp_score = np.median((in_norm * in_bias), axis=1) / \ 63 | np.max(np.median(np.abs(in_norm * in_bias), axis=1)) 64 | relative_depth = np.arange(n_layers) / (n_layers - 1) 65 | ax.plot(relative_depth, sp_score, label=model.split('-')[-1]) 66 | 67 | ax.legend(ncol=2, loc='lower right', 68 | title='pythia model', fontsize='small') 69 | ax.axhline(0, color='black', linestyle='--', linewidth=1, alpha=0.5) 70 | ax.set_xlabel('relative layer depth') 71 | ax.set_ylabel('normalized median($||W_{in}||_2 b_{in}$)') 72 | ax.set_xlim(-0.005, 1.005) 73 | # turn off top and right spines 74 | ax.spines['top'].set_visible(False) 75 | ax.spines['right'].set_visible(False) 76 | -------------------------------------------------------------------------------- /cache_downloads.py: -------------------------------------------------------------------------------- 1 | # Script for saving models and data to centralized storage on the cluster. 2 | # This script only ever needs to be run one time per model or dataset. 3 | # Use methods in [load.py] to access the downloaded models/datasets. 
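# Example invocation (the model name is one used elsewhere in this repo; the
# dataset name is an illustrative placeholder, not a tested configuration):
#   python cache_downloads.py --model pythia-70m --dataset <hf_dataset_name> --split train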
4 | 5 | # Make sure environment variables are set before running this script 6 | 7 | import argparse 8 | import os 9 | 10 | 11 | def save_model(model_name): 12 | # Just loading the model once will cache it to TRANSFORMERS_CACHE 13 | from transformer_lens import HookedTransformer 14 | HookedTransformer.from_pretrained(model_name, device='cpu') 15 | 16 | 17 | def save_dataset(dataset_name, split): 18 | import datasets 19 | dataset = datasets.load_dataset(dataset_name, split=split) 20 | # cache doesn't work since there is no loader script 21 | # see workaround https://github.com/huggingface/datasets/issues/3547#issuecomment-1252503988 22 | save_path = os.path.join( 23 | os.environ['HF_DATASETS_CACHE'], dataset_name, split) 24 | dataset.save_to_disk(save_path) 25 | 26 | 27 | if __name__ == '__main__': 28 | parser = argparse.ArgumentParser() 29 | parser.add_argument( 30 | '-m', '--model', default=None, help='Name of model from TransformerLens') 31 | parser.add_argument( 32 | '-d', '--dataset', default=None, help='Name of dataset from HF') 33 | parser.add_argument( 34 | '-s', '--split', default=None, help='Name of split for dataset from HF') 35 | 36 | args = vars(parser.parse_args()) 37 | 38 | if args['model'] is not None: 39 | model = args['model'] 40 | print(f'Saving model {model}') 41 | save_model(model) 42 | 43 | if args['dataset'] is not None: 44 | dataset = args['dataset'] 45 | split = args['split'] if args['split'] is not None else 'train' 46 | print(f'Saving split {split} of {dataset}') 47 | save_dataset(dataset, split) 48 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | import os 2 | from dataclasses import dataclass 3 | 4 | 5 | @dataclass 6 | class ExperimentConfig: 7 | def __init__(self, args, feature_dataset_cfg): 8 | self.experiment_name = args.get('experiment_name') 9 | self.experiment_type = args.get('experiment_type') 10 | self.model_name = args.get('model') 11 | self.dataset_cfg = feature_dataset_cfg 12 | self.feature_dataset = args['feature_dataset'] 13 | self.probe_location = args.get('probe_location', 'mlp.hook_post') 14 | self.activation_aggregation = args.get('activation_aggregation', None) 15 | self.normalize_activations = args.get('normalize_activations', False) 16 | self.seed = args.get('seed', 1) 17 | self.test_set_frac = args.get('test_set_frac', 0.3) 18 | self.batch_size = args.get('batch_size', 16) 19 | self.save_features_together = args.get( 20 | 'save_features_together', False) 21 | self.feature_subset = args.get('feature_subset', '') 22 | self.probe_next_token_feature = args.get( 23 | 'probe_next_token_feature', False) 24 | self.heuristic_feature_selection_method = args.get( 25 | 'heuristic_feature_selection_method', 'mean_dif') 26 | self.max_k = args.get('max_k', 128) 27 | self.osp_upto_k = args.get('osp_upto_k', 5) 28 | self.osp_heuristic_filter_size = args.get( 29 | 'osp_heuristic_filter_size', 50) 30 | self.gurobi_timeout = args.get('gurobi_timeout', 60) 31 | self.gurobi_verbose = args.get('gurobi_verbose', False) 32 | self.iterative_pruning_fixed_k = args.get( 33 | 'iterative_pruning_fixed_k', 5) 34 | self.iterative_pruning_n_prune_steps = args.get( 35 | 'iterative_pruning_n_prune_steps', 10) 36 | self.max_per_class = args.get('max_per_class', -1) 37 | 38 | 39 | @dataclass 40 | class FeatureDatasetConfig: 41 | def __init__( 42 | self, 43 | dataset_name, 44 | tokenizer_name, 45 | ctx_len, 46 | n_sequences, 47 | ): 48 | 
self.dataset_name = dataset_name 49 | self.tokenizer_name = tokenizer_name 50 | self.ctx_len = ctx_len 51 | self.n_sequences = n_sequences 52 | 53 | def make_dir_name(self): 54 | save_dir = '.'.join([ 55 | self.dataset_name, 56 | self.tokenizer_name, 57 | str(self.ctx_len), 58 | str(self.n_sequences), 59 | ]) 60 | return save_dir 61 | 62 | 63 | def parse_dataset_args(feature_dataset_string): 64 | ds_args = feature_dataset_string.split('.') 65 | feature_collection, tokenizer_name, seq_len, n_seqs = ds_args 66 | feature_dataset_cfg = FeatureDatasetConfig( 67 | feature_collection, 68 | tokenizer_name, 69 | int(seq_len), 70 | int(n_seqs), 71 | ) 72 | return feature_dataset_cfg 73 | -------------------------------------------------------------------------------- /experiments/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wesg52/sparse-probing-paper/a610e102c6e25a6ef9cc16c3a2abb736aa90849b/experiments/__init__.py -------------------------------------------------------------------------------- /experiments/activations.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import einops 3 | from torch.utils.data import DataLoader 4 | from datasets import Dataset 5 | 6 | 7 | def save_activation(tensor, hook): 8 | hook.ctx['activation'] = tensor.detach().cpu().to(torch.float16) 9 | 10 | 11 | def process_activation_batch(exp_cfg, batch_activations, step, index_mask=None): 12 | cur_batch_size = batch_activations.shape[0] 13 | 14 | if exp_cfg.activation_aggregation is None: 15 | # only save the activations for the required indices 16 | offset = step * exp_cfg.batch_size * batch_activations.shape[1] 17 | batch_end = cur_batch_size * batch_activations.shape[1] 18 | batch_activations = einops.rearrange( 19 | batch_activations, 'b c d -> (b c) d') # batch, context, dim 20 | processed_activations = batch_activations[index_mask[offset:offset+batch_end]] 21 | 22 | elif exp_cfg.activation_aggregation == 'mean': 23 | # average over the context dimension for valid tokens only 24 | offset = step * exp_cfg.batch_size 25 | batch_mask = index_mask[offset: offset+cur_batch_size, :, None] 26 | masked_activations = batch_activations * batch_mask 27 | batch_valid_ixs = index_mask[offset:offset+cur_batch_size].sum(dim=1) 28 | processed_activations = masked_activations.sum( 29 | dim=1) / batch_valid_ixs[:, None] 30 | 31 | elif exp_cfg.activation_aggregation == 'max': 32 | # max over the context dimension for valid tokens only (set invalid tokens to -1) 33 | offset = step * exp_cfg.batch_size 34 | batch_mask = index_mask[offset: offset+cur_batch_size, :, None].to(int) 35 | masked_activations = batch_activations * batch_mask + (batch_mask - 1) 36 | processed_activations = batch_activations.max(dim=1)[0] 37 | 38 | return processed_activations 39 | 40 | 41 | @torch.no_grad() 42 | def get_activation_dataset( 43 | exp_cfg, model, text_dataset, layer_ix, index_mask=None 44 | ): 45 | hook_pt = f'blocks.{layer_ix}.{exp_cfg.probe_location}' 46 | 47 | dataloader = DataLoader( 48 | text_dataset['tokens'], batch_size=exp_cfg.batch_size, shuffle=False) 49 | layer_activations = [] 50 | 51 | for step, batch in enumerate(dataloader): 52 | model.run_with_hooks( 53 | batch, 54 | fwd_hooks=[(hook_pt, save_activation)], 55 | stop_at_layer=layer_ix + 1, 56 | ) 57 | 58 | batch_activations = model.hook_dict[hook_pt].ctx['activation'] 59 | 60 | processed_activations = process_activation_batch( 61 | exp_cfg, 
batch_activations, step, index_mask) 62 | 63 | layer_activations.append(processed_activations) 64 | model.reset_hooks() 65 | 66 | activation_dataset = torch.concat(layer_activations, dim=0).numpy() 67 | return activation_dataset 68 | -------------------------------------------------------------------------------- /experiments/common_configs.py: -------------------------------------------------------------------------------- 1 | # from ..config import ExperimentConfig 2 | 3 | # exp_cfg = ExperimentConfig( 4 | # experiment_name=args['experiment_name'], 5 | # experiment_type=args['experiment_type'], 6 | # model_name=args['model'], 7 | # dataset_cfg=dataset_cfg, 8 | # feature_datasets=args['feature_datasets'], 9 | # min_pos_examples=args['min_pos_examples'], 10 | # feature_dataset_size=args['feature_dataset_size'], 11 | # feature_dataset_test_frac=args['feature_dataset_test_frac'], 12 | # osp_upto_k=args['osp_upto_k'], 13 | # osp_heuristic_filter_size=args['osp_heuristic_filter_size'], 14 | # gurobi_timeout=args['gurobi_timeout'], 15 | # gurobi_verbose=args['gurobi_verbose'], 16 | # ) -------------------------------------------------------------------------------- /experiments/layer_cached_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import transformer_lens 3 | 4 | 5 | class LayerCachedModel: 6 | ''' 7 | Wrapper class for TransformerLens model to enable running a model as 8 | for each layer for each batch get_activations() rather than 9 | for each batch for each layer get_activations() 10 | ''' 11 | 12 | def __init__(self, model, dataloader): 13 | self.model = model 14 | self.dataloader = dataloader 15 | 16 | self.current_layer = -1 17 | self.cached_residual_stream = None 18 | 19 | def get_activations(self, layer): 20 | assert layer >= self.current_layer 21 | 22 | # TODO: run model from [current_layer] to [layer] while returning the activations 23 | # You may need to convert to float16 to avoid memory issues 24 | 25 | self.current_layer = layer 26 | 27 | 28 | # TODO: test correctness (cached model prediction == regular model prediction) 29 | # TODO: test speed of caching vs. 
redoing computation 30 | -------------------------------------------------------------------------------- /experiments/metrics.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from sklearn.metrics import * 3 | 4 | 5 | def downsample_perf_curves(curve, pts_to_keep=100): 6 | n = len(curve) 7 | if n <= pts_to_keep: 8 | return curve 9 | else: 10 | idx = np.round(np.linspace(0, n - 1, pts_to_keep)).astype(int) 11 | return curve[idx] 12 | 13 | 14 | def get_binary_cls_perf_metrics(y_test, y_pred, y_score): 15 | precision, recall, _ = precision_recall_curve(y_test, y_score) 16 | fowlkes_mallows_index = (precision_score( 17 | y_test, y_pred) * recall_score(y_test, y_pred))**0.5 18 | classifier_results = { 19 | 'test_mcc': matthews_corrcoef(y_test, y_pred), 20 | 'test_cohen_kappa': cohen_kappa_score(y_test, y_pred), 21 | 'test_fmi': fowlkes_mallows_index, 22 | 'test_f1_score': f1_score(y_test, y_pred), 23 | 'test_f0.5_score': fbeta_score(y_test, y_pred, beta=0.5), 24 | 'test_f2_score': fbeta_score(y_test, y_pred, beta=2), 25 | 'test_pr_auc': auc(recall, precision), 26 | 'test_acc': accuracy_score(y_test, y_pred), 27 | 'test_balanced_acc': balanced_accuracy_score(y_test, y_pred), 28 | 'test_precision': precision_score(y_test, y_pred), 29 | 'test_recall': recall_score(y_test, y_pred), 30 | 'test_average_precision': average_precision_score(y_test, y_pred), 31 | 'test_roc_auc': roc_auc_score(y_test, y_score), 32 | 'test_precision_curve': downsample_perf_curves(precision), 33 | 'test_recall_curve': downsample_perf_curves(recall), 34 | } 35 | return classifier_results 36 | 37 | 38 | def get_regression_perf_metrics(y_test, y_pred): 39 | return { 40 | 'explained_variance': explained_variance_score(y_test, y_pred), 41 | 'max_error': max_error(y_test, y_pred), 42 | 'mean_absolute_error': mean_absolute_error(y_test, y_pred), 43 | 'mean_squared_error': mean_squared_error(y_test, y_pred), 44 | 'median_absolute_error': median_absolute_error(y_test, y_pred), 45 | 'r2': r2_score(y_test, y_pred), 46 | 'mean_absolute_percentage_error': mean_absolute_percentage_error( 47 | y_test, y_pred), 48 | 'd2_absolute_error': d2_absolute_error_score(y_test, y_pred), 49 | 'd2_pinball_score': d2_pinball_score(y_test, y_pred), 50 | 'd2_tweedie_score': d2_tweedie_score(y_test, y_pred), 51 | } 52 | -------------------------------------------------------------------------------- /experiments/regression_inner_loops.py: -------------------------------------------------------------------------------- 1 | import time 2 | import math 3 | from sklearn.model_selection import train_test_split 4 | from sklearn.linear_model import ElasticNet 5 | from .regression_probes import * 6 | from .metrics import get_regression_perf_metrics 7 | 8 | 9 | def make_regression_k_list(): 10 | base_ks = [1, 2, 3, 4, 5, 6, 8, 9, 10, 12, 14] 11 | exp_ks = list((2 ** np.linspace(4, 8, 13)).astype(int)) 12 | return base_ks + exp_ks 13 | 14 | 15 | def dense_regression_probe(exp_cfg, activation_dataset, regression_target): 16 | """ 17 | Train a dense probe on the activation dataset. 18 | 19 | Parameters 20 | ---------- 21 | exp_cfg : as specified by the CLI in probing_experiment.py 22 | activation_dataset : np.ndarray (n_samples, n_neurons) 23 | regression_target : np.ndarray (n_samples) with regression targets. 
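Returns
-------
dict of regression performance metrics (from get_regression_perf_metrics)
plus 'elapsed_time', 'n_iter', and 'coef' for the fitted ElasticNet probe.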
24 | """ 25 | X_train, X_test, y_train, y_test = train_test_split( 26 | activation_dataset, regression_target, 27 | test_size=exp_cfg.test_set_frac, random_state=exp_cfg.seed) 28 | 29 | lr = ElasticNet(precompute=True) 30 | 31 | start_t = time.time() 32 | lr = lr.fit(X_train, y_train) 33 | elapsed_time = time.time() - start_t 34 | lr_pred = lr.predict(X_test) 35 | 36 | results = get_regression_perf_metrics(y_test, lr_pred) 37 | results['elapsed_time'] = elapsed_time 38 | results['n_iter'] = lr.n_iter_ 39 | results['coef'] = lr.coef_ 40 | return results 41 | 42 | 43 | def heuristic_sparse_regression_sweep(exp_cfg, activation_dataset, regression_target): 44 | """ 45 | Train a heuristic sparse probe on the activation dataset for varying k. 46 | 47 | Parameters 48 | ---------- 49 | exp_cfg : as specified by the CLI in probing_experiment.py 50 | activation_dataset : np.ndarray (n_samples, n_neurons) 51 | regression_target : np.ndarray (n_samples) with regression targets. 52 | """ 53 | X_train, X_test, y_train, y_test = train_test_split( 54 | activation_dataset, regression_target, 55 | test_size=exp_cfg.test_set_frac, random_state=exp_cfg.seed) 56 | 57 | neuron_ranking = get_heuristic_neuron_ranking_regression( 58 | X_train, y_train, 'f_stat') 59 | 60 | layer_results = {} 61 | for k in make_regression_k_list()[::-1]: 62 | support = np.sort(neuron_ranking[-k:]) 63 | lr = ElasticNet(precompute=True) 64 | start_t = time.time() 65 | lr = lr.fit(X_train[:, support], y_train) 66 | elapsed_time = time.time() - start_t 67 | 68 | lr_pred = lr.predict(X_test[:, support]) 69 | layer_results[k] = get_regression_perf_metrics(y_test, lr_pred) 70 | layer_results[k]['elapsed_time'] = elapsed_time 71 | layer_results[k]['n_iter'] = lr.n_iter_ 72 | layer_results[k]['coef'] = lr.coef_ 73 | layer_results[k]['support'] = support 74 | 75 | # rerank according to the linear regression coefficients 76 | neuron_ranking = np.zeros(len(neuron_ranking)) 77 | neuron_ranking[support] = np.abs(lr.coef_) 78 | neuron_ranking = np.argsort(neuron_ranking) 79 | 80 | return layer_results 81 | 82 | 83 | def optimal_sparse_regression_probe(exp_cfg, activation_dataset, regression_target): 84 | """ 85 | Train a sparse probe on the activation dataset. 86 | 87 | Parameters 88 | ---------- 89 | exp_cfg : as specified by the CLI in probing_experiment.py 90 | activation_dataset : np.ndarray (n_samples, n_neurons) 91 | regression_target : np.ndarray (n_samples) with regression targets. 
92 | """ 93 | raise NotImplementedError 94 | -------------------------------------------------------------------------------- /experiments/regression_probes.py: -------------------------------------------------------------------------------- 1 | import gurobipy as gp 2 | import numpy as np 3 | from sklearn.feature_selection import f_regression, mutual_info_regression, r_regression 4 | from sklearn.linear_model import LinearRegression, Lasso 5 | 6 | 7 | def solve_inner_problem(X, Y, s, gamma): 8 | indices = np.where(s > 0.5)[0] 9 | n, d = X.shape 10 | denom = 2*n 11 | Xs = X[:, indices] 12 | 13 | alpha = Y - Xs @ (np.linalg.inv(np.eye(len(indices)) / 14 | gamma + Xs.T @ Xs) @ (Xs.T @ Y)) 15 | obj = np.dot(Y, alpha) / denom 16 | tmp = X.T @ alpha 17 | grad = -gamma * tmp**2 / denom 18 | return obj, grad 19 | 20 | 21 | def sparse_regression_oa(X, Y, k, gamma, s0, time_limit=60, verbose=True): 22 | n, d = X.shape 23 | 24 | gp_env = gp.Env() # need env for cluster 25 | model = gp.Model("classifier", env=gp_env) 26 | 27 | s = model.addVars(d, vtype=gp.GRB.BINARY, name="support") 28 | t = model.addVar(lb=0.0, vtype=gp.GRB.CONTINUOUS, name="objective") 29 | 30 | model.addConstr(gp.quicksum(s) <= k, name="l0") 31 | 32 | if len(s0) == 0: 33 | s0 = np.zeros(d) 34 | s0[range(int(k))] = 1 35 | 36 | obj0, grad0 = solve_inner_problem(X, Y, s0, gamma) 37 | model.addConstr( 38 | t >= obj0 + gp.quicksum(grad0[i] * (s[i] - s0[i]) for i in range(d))) 39 | model.setObjective(t, gp.GRB.MINIMIZE) 40 | 41 | def outer_approximation(model, where): 42 | if where == gp.GRB.Callback.MIPSOL: 43 | s_bar = model.cbGetSolution(model._vars) 44 | s_vals = np.array([a for a in s_bar.values()]) 45 | obj, grad = solve_inner_problem(X, Y, s_vals, gamma) 46 | model.cbLazy( 47 | t >= obj + gp.quicksum(grad[i] * (s[i] - s_vals[i]) for i in range(d))) 48 | 49 | model._vars = s 50 | model.params.OutputFlag = 1 if verbose else 0 51 | model.Params.lazyConstraints = 1 52 | model.Params.timeLimit = time_limit 53 | model.optimize(outer_approximation) 54 | 55 | support_indices = sorted([i for i in range(len(s)) if s[i].X > 0.5]) 56 | 57 | X_s = X[:, support_indices] 58 | beta = np.zeros(d) 59 | sol = np.linalg.solve(np.eye(int(k)) / gamma + X_s.T @ X_s, X_s.T @ Y) 60 | beta[support_indices] = gamma * X_s.T @ (Y - X_s @ sol) 61 | 62 | model_stats = { 63 | 'obj': model.ObjVal, 64 | 'obj_bound': model.ObjBound, 65 | 'mip_gap': model.MIPGap, 66 | 'model_status': model.Status, 67 | 'sol_count': model.SolCount, 68 | 'iter_count': model.IterCount, 69 | 'node_count': model.NodeCount, 70 | 'runtime': model.Runtime 71 | } 72 | 73 | model.dispose() 74 | gp_env.dispose() 75 | 76 | return model_stats, beta, support_indices 77 | 78 | 79 | def get_heuristic_neuron_ranking_regression(X, y, method): 80 | if method == 'l1': 81 | lr = Lasso() 82 | lr = lr.fit(X, y) 83 | ranks = np.argsort(np.abs(lr.coef_[0])) 84 | elif method == 'f_stat': 85 | f_stat, p_val = f_regression(X, y) 86 | ranks = np.argsort(f_stat) 87 | elif method == 'mi': 88 | mi = mutual_info_regression(X, y) 89 | ranks = np.argsort(mi) 90 | 91 | elif method == 'correlation': 92 | corr = r_regression(X, y) 93 | ranks = np.argsort(np.abs(corr)) 94 | else: 95 | raise ValueError('Invalid method') 96 | return ranks 97 | -------------------------------------------------------------------------------- /feature_selection_experiment.py: -------------------------------------------------------------------------------- 1 | import math 2 | import random 3 | import numpy as np 4 | import os 5 | import 
pickle 6 | from sklearn import svm 7 | from sklearn.feature_selection import f_classif, mutual_info_classif 8 | from sklearn.linear_model import LogisticRegression 9 | from sklearn.svm import LinearSVC 10 | from sklearn.model_selection import train_test_split 11 | from sklearn.metrics import * 12 | 13 | 14 | def gelu(x): 15 | return x / (1 + np.exp(-1.702 * x)) 16 | 17 | 18 | def fake_gelu_dataset(n, d, k, mean_shift=1, imbalance=0.5): 19 | x = np.random.normal(0, 1, size=(n, d)) 20 | true_support = np.random.choice(d, size=k, replace=False) 21 | true_classes = np.random.choice(n, size=int(n*imbalance), replace=False) 22 | y = np.zeros(n) 23 | y[true_classes] = 1 24 | y = y * 2 - 1 # to {+1, -1} 25 | 26 | x[:, true_support] += y[:, None] * mean_shift 27 | activations = gelu(x) 28 | return activations, y, true_support 29 | 30 | 31 | def log_uniform(a, b): 32 | return 10**random.uniform(math.log10(a), math.log10(b)) 33 | 34 | 35 | def score_features(feature_scores, s_set): 36 | k = len(s_set) 37 | top_features = set(np.argsort(feature_scores)[-k:]) 38 | return len(top_features.intersection(s_set)) / k 39 | 40 | 41 | def feature_selection_results(X, y, s): 42 | pos_class = (y * 2 - 1).astype(bool) 43 | mean_dif = X[pos_class, :].mean(axis=0) - X[~pos_class, :].mean(axis=0) 44 | 45 | mi = mutual_info_classif(X, y) 46 | 47 | f_stat, p_val = f_classif(X, y) 48 | 49 | lr = LogisticRegression(class_weight='balanced', 50 | penalty='l2', solver='saga', n_jobs=-1) 51 | lr = lr.fit(X, y) 52 | 53 | svm = LinearSVC(loss='hinge') 54 | svm = svm.fit(X, y) 55 | 56 | s_set = set(s) 57 | 58 | coeff_acc = { 59 | 'mean_dif': score_features(np.abs(mean_dif), s_set), 60 | 'mi': score_features(mi, s_set), 61 | 'f_stat': score_features(f_stat, s_set), 62 | 'lr_mag': score_features(np.abs(lr.coef_[0]), s_set), 63 | 'svm_mag': score_features(np.abs(svm.coef_[0]), s_set), 64 | } 65 | return coeff_acc 66 | 67 | 68 | def run_synthetic_data_experiment(n_trials): 69 | results = {} 70 | for i in range(n_trials): 71 | n = int(log_uniform(500, 30_000)) 72 | d = random.choice(50 * 2**np.arange(3, 8)) 73 | k = np.random.choice(np.arange(2, 20)) 74 | mean_shift = log_uniform(0.05, 5) 75 | class_imbalance = log_uniform(0.01, 0.5) 76 | 77 | X, y, s = fake_gelu_dataset( 78 | n, d, k, mean_shift=mean_shift, imbalance=class_imbalance) 79 | 80 | trial_results = feature_selection_results(X, y, s) 81 | results[i] = trial_results 82 | return results 83 | 84 | 85 | if __name__ == '__main__': 86 | EXPERIMENT_NAME = 'synthetic_data_feature_selection_comparison' 87 | N_TRIALS = 200 88 | 89 | seed = int(os.getenv('SLURM_ARRAY_TASK_ID', 1)) 90 | random.seed(seed) 91 | np.random.seed(seed) 92 | 93 | save_path = os.path.join( 94 | os.getenv('RESULTS_DIR', 'results'), EXPERIMENT_NAME) 95 | os.makedirs(save_path, exist_ok=True) 96 | file_name = f'shard_{seed}.p' 97 | 98 | results = run_synthetic_data_experiment(N_TRIALS) 99 | pickle.dump(results, open(os.path.join(save_path, file_name), 'wb')) 100 | -------------------------------------------------------------------------------- /interpretable_neurons/pythia-1.4b/biasxnorm_p20_50_80.csv: -------------------------------------------------------------------------------- 1 | layer,neuron,percentile,biasxnorm 2 | 1,1520,20,-0.5272459713197328 3 | 1,1937,50,-0.4403957690300331 4 | 1,542,80,-0.3060104747136503 5 | -------------------------------------------------------------------------------- /interpretable_neurons/pythia-160m/biasxnorm_p20_50_80.csv: 
-------------------------------------------------------------------------------- 1 | layer,neuron,percentile,biasxnorm 2 | 1,1981,20,-0.5375263263020287 3 | 1,297,50,-0.4137588585324323 4 | 1,1657,80,-0.2809535855534335 5 | -------------------------------------------------------------------------------- /interpretable_neurons/pythia-160m/top_compound_words.csv: -------------------------------------------------------------------------------- 1 | model,dataset,feature,layer,aggregation,hook_loc,neuron,test_mcc,test_f1_score,test_precision,test_recall,coef 2 | pythia-160m,compound_words.pyth.24.-1,prime-factors,1,none,mlp.hook_post,1611,0.9959840091200889,0.9967497291440953,0.9935205183585313,1.0,10.184043 3 | pythia-160m,compound_words.pyth.24.-1,social-security,2,none,mlp.hook_post,880,0.99098957035494,0.9927760577915377,0.985655737704918,1.0,10.384269 4 | pythia-160m,compound_words.pyth.24.-1,blood-pressure,1,none,mlp.hook_post,601,0.9811332282631159,0.984894259818731,0.9702380952380952,1.0,7.40047 5 | pythia-160m,compound_words.pyth.24.-1,social-media,3,none,mlp.hook_post,2347,0.979350773155441,0.9837940896091516,0.9735849056603774,0.9942196531791907,8.167716 6 | pythia-160m,compound_words.pyth.24.-1,north-america,1,none,mlp.hook_post,33,0.9698312186442014,0.9754010695187165,0.9539748953974896,0.9978118161925602,7.8169246 7 | pythia-160m,compound_words.pyth.24.-1,high-school,1,none,mlp.hook_post,1017,0.9647279050073813,0.9714285714285714,0.9444444444444444,1.0,6.387287 8 | pythia-160m,compound_words.pyth.24.-1,side-effects,2,none,mlp.hook_post,418,0.9646864661461422,0.971307120085016,0.9501039501039501,0.9934782608695653,5.0160966 9 | pythia-160m,compound_words.pyth.24.-1,living-room,3,none,mlp.hook_post,451,0.9546510073856874,0.9630434782608696,0.940552016985138,0.9866369710467706,7.284157 10 | pythia-160m,compound_words.pyth.24.-1,human-rights,1,none,mlp.hook_post,246,0.9484184100453665,0.9593956562795091,0.9355432780847146,0.9844961240310077,6.538414 11 | pythia-160m,compound_words.pyth.24.-1,mental-health,1,none,mlp.hook_post,1202,0.9482272379101402,0.9581589958158996,0.9346938775510204,0.9828326180257511,10.680516 12 | pythia-160m,compound_words.pyth.24.-1,cell-lines,1,none,mlp.hook_post,886,0.9383661955682125,0.9493670886075949,0.9090909090909091,0.9933774834437086,7.5405393 13 | pythia-160m,compound_words.pyth.24.-1,trial-court,1,none,mlp.hook_post,2492,0.9362852867052558,0.948905109489051,0.9362139917695473,0.9619450317124736,10.206216 14 | pythia-160m,compound_words.pyth.24.-1,public-health,1,none,mlp.hook_post,2530,0.9235850751204425,0.9383490073145245,0.9126016260162602,0.9655913978494624,8.073491 15 | pythia-160m,compound_words.pyth.24.-1,gene-expression,2,none,mlp.hook_post,1981,0.9164205914838274,0.9319796954314721,0.8861003861003861,0.9828693790149893,13.065329 16 | pythia-160m,compound_words.pyth.24.-1,third-party,1,none,mlp.hook_post,197,0.908994936818978,0.9263370332996973,0.8826923076923077,0.9745222929936306,4.7114515 17 | pythia-160m,compound_words.pyth.24.-1,magnetic-field,1,none,mlp.hook_post,444,0.8998210946626132,0.9190751445086704,0.8641304347826086,0.9814814814814815,6.211975 18 | pythia-160m,compound_words.pyth.24.-1,credit-card,2,none,mlp.hook_post,2628,0.8907060792054698,0.9100418410041842,0.8546168958742633,0.9731543624161074,7.9918036 19 | pythia-160m,compound_words.pyth.24.-1,control-group,1,none,mlp.hook_post,1555,0.871007406443597,0.896486229819563,0.8383658969804618,0.963265306122449,5.6667 20 | 
pythia-160m,compound_words.pyth.24.-1,clinical-trials,1,none,mlp.hook_post,276,0.8656409763093972,0.8893401015228426,0.8217636022514071,0.9690265486725663,7.3410864 21 | pythia-160m,compound_words.pyth.24.-1,federal-government,2,none,mlp.hook_post,874,0.8131255450135206,0.8461538461538461,0.7586206896551724,0.9565217391304348,10.181748 22 | pythia-160m,compound_words.pyth.24.-1,second-derivative,2,none,mlp.hook_post,2366,0.7568112770745069,0.8020477815699659,0.6921944035346097,0.9533468559837728,7.739948 23 | -------------------------------------------------------------------------------- /interpretable_neurons/pythia-1b/biasxnorm_p20_50_80.csv: -------------------------------------------------------------------------------- 1 | layer,neuron,percentile,biasxnorm 2 | 1,1686,20,-0.8611524307076479 3 | 1,127,50,-0.6881529589863682 4 | 1,7533,80,-0.41686048263435005 5 | -------------------------------------------------------------------------------- /interpretable_neurons/pythia-1b/compound_superposition.csv: -------------------------------------------------------------------------------- 1 | feature,layer,neuron 2 | prime-factors,2,139 3 | prime-factors,2,352 4 | prime-factors,2,2650 5 | prime-factors,2,4500 6 | prime-factors,2,5717 7 | human-rights,1,1496 8 | human-rights,1,3964 9 | human-rights,1,4408 10 | human-rights,1,4830 11 | human-rights,1,7445 12 | social-security,2,611 13 | social-security,2,4081 14 | social-security,2,6819 15 | social-security,2,7225 16 | social-security,2,7613 17 | federal-government,3,303 18 | federal-government,3,2886 19 | federal-government,3,3451 20 | federal-government,3,4728 21 | federal-government,3,7947 22 | second-derivative,2,1427 23 | second-derivative,2,2110 24 | second-derivative,2,3830 25 | second-derivative,2,3969 26 | second-derivative,2,6646 27 | social-media,3,1042 28 | social-media,3,1230 29 | social-media,3,1463 30 | social-media,3,2471 31 | social-media,3,4656 32 | blood-pressure,2,17 33 | blood-pressure,2,663 34 | blood-pressure,2,902 35 | blood-pressure,2,5639 36 | blood-pressure,2,5820 37 | north-america,1,2124 38 | north-america,1,3567 39 | north-america,1,6457 40 | north-america,1,6675 41 | north-america,1,6692 42 | living-room,3,163 43 | living-room,3,3521 44 | living-room,3,3558 45 | living-room,3,5671 46 | living-room,3,7053 47 | gene-expression,1,901 48 | gene-expression,1,1185 49 | gene-expression,1,3596 50 | gene-expression,1,6105 51 | gene-expression,1,6765 52 | mental-health,1,161 53 | mental-health,1,3652 54 | mental-health,1,4659 55 | mental-health,1,5996 56 | mental-health,1,8163 57 | trial-court,2,140 58 | trial-court,2,2296 59 | trial-court,2,3927 60 | trial-court,2,4008 61 | trial-court,2,6312 62 | public-health,2,293 63 | public-health,2,2289 64 | public-health,2,3329 65 | public-health,2,4382 66 | public-health,2,5879 67 | high-school,2,1462 68 | high-school,2,4793 69 | high-school,2,5886 70 | high-school,2,6580 71 | high-school,2,7035 72 | cell-lines,3,103 73 | cell-lines,3,802 74 | cell-lines,3,3930 75 | cell-lines,3,5738 76 | cell-lines,3,7788 77 | credit-card,2,1039 78 | credit-card,2,1177 79 | credit-card,2,2537 80 | credit-card,2,5967 81 | credit-card,2,7964 82 | side-effects,1,57 83 | side-effects,1,2576 84 | side-effects,1,4530 85 | side-effects,1,6047 86 | side-effects,1,6266 87 | control-group,2,3455 88 | control-group,2,4414 89 | control-group,2,5234 90 | control-group,2,5709 91 | control-group,2,6175 92 | clinical-trials,2,911 93 | clinical-trials,2,3502 94 | clinical-trials,2,3712 95 | clinical-trials,2,7094 
96 | clinical-trials,2,7889 97 | magnetic-field,1,965 98 | magnetic-field,1,1618 99 | magnetic-field,1,6053 100 | magnetic-field,1,6135 101 | magnetic-field,1,6647 102 | third-party,2,1668 103 | third-party,2,4813 104 | third-party,2,5274 105 | third-party,2,5794 106 | third-party,2,7629 107 | -------------------------------------------------------------------------------- /interpretable_neurons/pythia-1b/monosemantic_code_neurons.csv: -------------------------------------------------------------------------------- 1 | model,dataset,feature,layer,aggregation,hook_loc,neuron,test_mcc,test_f1_score,test_precision,test_recall,coef 2 | pythia-1b,programming_lang_id.pyth.512.-1,Go,6,mean,mlp.hook_post,3108,0.9647745660292385,0.9689119170984456,0.9790575916230366,0.958974358974359,8.226216 3 | pythia-1b,programming_lang_id.pyth.512.-1,PHP,9,mean,mlp.hook_post,7926,0.9553678431247704,0.9577464788732394,0.9902912621359223,0.9272727272727272,9.4361925 4 | pythia-1b,programming_lang_id.pyth.512.-1,Go,1,mean,mlp.hook_post,884,0.9467680510682587,0.9516129032258065,1.0,0.9076923076923077,18.493044 5 | pythia-1b,programming_lang_id.pyth.512.-1,Go,9,mean,mlp.hook_post,1823,0.9442817822227344,0.9509043927648579,0.9583333333333334,0.9435897435897436,13.344632 6 | pythia-1b,programming_lang_id.pyth.512.-1,Python,10,mean,mlp.hook_post,3855,0.9400676873103173,0.945054945054945,0.9555555555555556,0.9347826086956522,8.521176 7 | pythia-1b,programming_lang_id.pyth.512.-1,Python,9,mean,mlp.hook_post,1693,0.9354155185612765,0.9402985074626865,0.9692307692307692,0.9130434782608695,10.610265 8 | pythia-1b,programming_lang_id.pyth.512.-1,Python,6,mean,mlp.hook_post,7172,0.9352624971914606,0.9398496240601504,0.9765625,0.9057971014492754,13.13241 9 | pythia-1b,programming_lang_id.pyth.512.-1,Java,6,mean,mlp.hook_post,4070,0.9124065157258477,0.9263565891472868,0.9263565891472868,0.9263565891472868,12.90993 10 | pythia-1b,programming_lang_id.pyth.512.-1,PHP,10,mean,mlp.hook_post,5633,0.9042516486109267,0.9107142857142856,0.8947368421052632,0.9272727272727272,10.540982 11 | pythia-1b,programming_lang_id.pyth.512.-1,PHP,9,mean,mlp.hook_post,1808,0.8903332923046247,0.897196261682243,0.9230769230769231,0.8727272727272727,14.229025 12 | pythia-1b,programming_lang_id.pyth.512.-1,Java,10,mean,mlp.hook_post,3887,0.874336737217571,0.8942486085343229,0.8576512455516014,0.9341085271317829,4.555989 13 | pythia-1b,programming_lang_id.pyth.512.-1,Java,11,mean,mlp.hook_post,1870,0.8571495805676833,0.8798521256931607,0.8409893992932862,0.9224806201550387,14.687348 14 | pythia-1b,programming_lang_id.pyth.512.-1,C,1,mean,mlp.hook_post,3398,0.8183904415830677,0.8463073852295409,0.8617886178861789,0.8313725490196079,16.776478 15 | -------------------------------------------------------------------------------- /interpretable_neurons/pythia-2.8b/biasxnorm_p20_50_80.csv: -------------------------------------------------------------------------------- 1 | layer,neuron,percentile,biasxnorm 2 | 1,4304,20,-0.7444907452201974 3 | 1,1354,50,-0.5961270867583295 4 | 1,3106,80,-0.31648927743026434 5 | -------------------------------------------------------------------------------- /interpretable_neurons/pythia-2.8b/top_compound_words.csv: -------------------------------------------------------------------------------- 1 | model,dataset,feature,layer,aggregation,hook_loc,neuron,test_mcc,test_f1_score,test_precision,test_recall,coef 2 | 
pythia-2.8b,compound_words.pyth.24.-1,prime-factors,2,none,mlp.hook_post,9771,0.9986550999226772,0.9989118607181718,1.0,0.9978260869565218,6.9789343 3 | pythia-2.8b,compound_words.pyth.24.-1,living-room,4,none,mlp.hook_post,6737,0.9918465673980807,0.9933628318584071,0.9868131868131869,1.0,6.4259844 4 | pythia-2.8b,compound_words.pyth.24.-1,social-media,5,none,mlp.hook_post,1107,0.987785803297301,0.9904214559386973,0.9847619047619047,0.9961464354527938,5.8557606 5 | pythia-2.8b,compound_words.pyth.24.-1,social-security,3,none,mlp.hook_post,8199,0.9832833384995171,0.9866117404737383,0.9775510204081632,0.9958419958419958,6.2409496 6 | pythia-2.8b,compound_words.pyth.24.-1,second-derivative,5,none,mlp.hook_post,8905,0.9821605562823835,0.9858299595141702,0.9838383838383838,0.9878296146044625,12.588413 7 | pythia-2.8b,compound_words.pyth.24.-1,mental-health,3,none,mlp.hook_post,4521,0.9619323557051809,0.9693121693121693,0.9561586638830898,0.9828326180257511,5.0580335 8 | pythia-2.8b,compound_words.pyth.24.-1,blood-pressure,3,none,mlp.hook_post,6915,0.9615481205546845,0.9691542288557214,0.9437984496124031,0.9959100204498977,6.87574 9 | pythia-2.8b,compound_words.pyth.24.-1,high-school,3,none,mlp.hook_post,1514,0.9529504394040703,0.9621289662231322,0.93812375249501,0.9873949579831933,5.0860415 10 | pythia-2.8b,compound_words.pyth.24.-1,trial-court,2,none,mlp.hook_post,6559,0.9523787712311456,0.9613034623217922,0.9273084479371316,0.9978858350951374,9.633029 11 | pythia-2.8b,compound_words.pyth.24.-1,public-health,3,none,mlp.hook_post,3567,0.9490190107696157,0.9585062240663901,0.9258517034068137,0.9935483870967742,5.8376412 12 | pythia-2.8b,compound_words.pyth.24.-1,side-effects,2,none,mlp.hook_post,5230,0.9464102662587212,0.9561586638830897,0.9196787148594378,0.9956521739130435,9.318477 13 | pythia-2.8b,compound_words.pyth.24.-1,cell-lines,2,none,mlp.hook_post,1918,0.9315006661697458,0.9435146443514644,0.8966202783300199,0.9955849889624724,4.922087 14 | pythia-2.8b,compound_words.pyth.24.-1,human-rights,3,none,mlp.hook_post,9341,0.9314401655687998,0.9458955223880597,0.9118705035971223,0.9825581395348837,9.791269 15 | pythia-2.8b,compound_words.pyth.24.-1,federal-government,4,none,mlp.hook_post,7480,0.920305997433491,0.9352818371607515,0.8995983935742972,0.9739130434782609,6.236562 16 | pythia-2.8b,compound_words.pyth.24.-1,gene-expression,3,none,mlp.hook_post,3090,0.9190692232677301,0.9344262295081969,0.8958742632612967,0.9764453961456103,10.129644 17 | pythia-2.8b,compound_words.pyth.24.-1,control-group,4,none,mlp.hook_post,6166,0.9092214950266629,0.9272550921435498,0.8835489833641405,0.9755102040816327,6.05653 18 | pythia-2.8b,compound_words.pyth.24.-1,credit-card,2,none,mlp.hook_post,9575,0.9064152388518025,0.9225941422594142,0.8664047151277013,0.9865771812080537,5.7136483 19 | pythia-2.8b,compound_words.pyth.24.-1,north-america,2,none,mlp.hook_post,1929,0.8963757043716488,0.9142280524722503,0.848314606741573,0.9912472647702407,7.7785225 20 | pythia-2.8b,compound_words.pyth.24.-1,third-party,3,none,mlp.hook_post,9675,0.893091631610025,0.912109375,0.8444846292947559,0.9915074309978769,4.264713 21 | pythia-2.8b,compound_words.pyth.24.-1,clinical-trials,2,none,mlp.hook_post,1742,0.8647535862986758,0.8877755511022043,0.8113553113553114,0.9800884955752213,6.3994236 22 | pythia-2.8b,compound_words.pyth.24.-1,magnetic-field,2,none,mlp.hook_post,9694,0.8578589130693968,0.8851224105461394,0.8159722222222222,0.9670781893004116,4.850628 23 | 
-------------------------------------------------------------------------------- /interpretable_neurons/pythia-2.8b/wikidata.csv: -------------------------------------------------------------------------------- 1 | model,dataset,feature,layer,neuron,test_mcc,test_f1_score,coef 2 | pythia-2.8b,wikidata_occupation.pyth.128.30000,actor,19,2176,0.349,0.419,7.259 3 | pythia-2.8b,wikidata_occupation.pyth.128.30000,actor,18,3679,0.345,0.419,7.383 4 | pythia-2.8b,wikidata_occupation.pyth.128.30000,actor,22,5361,0.325,0.403,8.036 5 | pythia-2.8b,wikidata_occupation.pyth.128.30000,association football player,18,5777,0.462,0.485,9.967 6 | pythia-2.8b,wikidata_occupation.pyth.128.30000,association football player,21,122,0.313,0.382,7.445 7 | pythia-2.8b,wikidata_occupation.pyth.128.30000,association football player,19,5431,0.315,0.379,6.723 8 | pythia-2.8b,wikidata_occupation.pyth.128.30000,basketball player,17,8218,0.391,0.409,6.780 9 | pythia-2.8b,wikidata_occupation.pyth.128.30000,basketball player,28,427,0.354,0.408,6.969 10 | pythia-2.8b,wikidata_occupation.pyth.128.30000,basketball player,18,2748,0.356,0.394,5.570 11 | pythia-2.8b,wikidata_occupation.pyth.128.30000,businessperson,8,10149,0.229,0.297,5.330 12 | pythia-2.8b,wikidata_occupation.pyth.128.30000,businessperson,23,7507,0.255,0.294,4.386 13 | pythia-2.8b,wikidata_occupation.pyth.128.30000,businessperson,8,2304,0.219,0.277,3.801 14 | pythia-2.8b,wikidata_occupation.pyth.128.30000,journalist,12,4755,0.301,0.376,7.181 15 | pythia-2.8b,wikidata_occupation.pyth.128.30000,journalist,22,6182,0.297,0.372,6.458 16 | pythia-2.8b,wikidata_occupation.pyth.128.30000,journalist,8,1425,0.285,0.363,7.059 17 | pythia-2.8b,wikidata_occupation.pyth.128.30000,lawyer,8,2503,0.257,0.346,6.342 18 | pythia-2.8b,wikidata_occupation.pyth.128.30000,lawyer,21,3745,0.248,0.339,5.342 19 | pythia-2.8b,wikidata_occupation.pyth.128.30000,lawyer,15,5167,0.243,0.335,5.324 20 | pythia-2.8b,wikidata_occupation.pyth.128.30000,politician,19,5225,0.264,0.347,5.674 21 | pythia-2.8b,wikidata_occupation.pyth.128.30000,politician,8,4334,0.234,0.321,4.781 22 | pythia-2.8b,wikidata_occupation.pyth.128.30000,politician,6,7823,0.229,0.318,5.694 23 | pythia-2.8b,wikidata_occupation.pyth.128.30000,researcher,10,9109,0.271,0.332,4.790 24 | pythia-2.8b,wikidata_occupation.pyth.128.30000,researcher,13,5933,0.244,0.324,4.982 25 | pythia-2.8b,wikidata_occupation.pyth.128.30000,researcher,13,5257,0.228,0.313,5.293 26 | pythia-2.8b,wikidata_occupation.pyth.128.30000,singer,20,8583,0.440,0.496,8.883 27 | pythia-2.8b,wikidata_occupation.pyth.128.30000,singer,20,9611,0.383,0.447,8.726 28 | pythia-2.8b,wikidata_occupation.pyth.128.30000,singer,13,9463,0.399,0.445,7.862 29 | pythia-2.8b,wikidata_occupation.pyth.128.30000,writer,22,3273,0.278,0.356,5.784 30 | pythia-2.8b,wikidata_occupation.pyth.128.30000,writer,9,934,0.247,0.334,5.415 31 | pythia-2.8b,wikidata_occupation.pyth.128.30000,writer,18,6992,0.251,0.333,5.822 -------------------------------------------------------------------------------- /interpretable_neurons/pythia-410m/biasxnorm_p20_50_80.csv: -------------------------------------------------------------------------------- 1 | layer,neuron,percentile,biasxnorm 2 | 1,3231,20,-0.27615977657066537 3 | 1,149,50,-0.2093177505968562 4 | 1,4003,80,-0.1399363165214247 5 | -------------------------------------------------------------------------------- /interpretable_neurons/pythia-6.9b/biasxnorm_p20_50_80.csv: -------------------------------------------------------------------------------- 1 
| layer,neuron,percentile,biasxnorm 2 | 1,132,20,-1.6590305880544065 3 | 1,14104,50,-1.3974292860494018 4 | 1,13190,80,-1.1011086758597344 5 | -------------------------------------------------------------------------------- /interpretable_neurons/pythia-6.9b/monosemantic_distribution_neurons.csv: -------------------------------------------------------------------------------- 1 | model,dataset,feature,layer,aggregation,hook_loc,neuron,test_mcc,test_f1_score,test_precision,test_recall,coef 2 | pythia-6.9b,distribution_id.pyth.512.-1,uspto,8,mean,mlp.hook_post,1816,0.9893209927879815,0.990625,1.0,0.9814241486068112,13.036818 3 | pythia-6.9b,distribution_id.pyth.512.-1,arxiv,31,mean,mlp.hook_post,13649,0.9826337132679703,0.986351228389445,0.9963235294117647,0.9765765765765766,15.236027 4 | pythia-6.9b,distribution_id.pyth.512.-1,uspto,7,mean,mlp.hook_post,7805,0.9768085321830406,0.9797191887675507,0.9874213836477987,0.9721362229102167,24.992254 5 | pythia-6.9b,distribution_id.pyth.512.-1,freelaw,14,mean,mlp.hook_post,12806,0.9710553525760423,0.9760348583877996,0.9977728285077951,0.9552238805970149,22.229414 6 | pythia-6.9b,distribution_id.pyth.512.-1,arxiv,27,mean,mlp.hook_post,11263,0.9642238442891552,0.9720972097209721,0.9712230215827338,0.972972972972973,15.291997 7 | pythia-6.9b,distribution_id.pyth.512.-1,freelaw,1,mean,mlp.hook_post,8710,0.9611194179436281,0.9683544303797468,0.9582463465553236,0.9786780383795309,-37.31863 8 | pythia-6.9b,distribution_id.pyth.512.-1,uspto,7,mean,mlp.hook_post,14461,0.9608491936304272,0.9658385093167702,0.9688473520249221,0.9628482972136223,27.479229 9 | pythia-6.9b,distribution_id.pyth.512.-1,arxiv,30,mean,mlp.hook_post,7320,0.9605321440737165,0.9685767097966729,0.9943074003795066,0.9441441441441442,21.24965 10 | pythia-6.9b,distribution_id.pyth.512.-1,freelaw,7,mean,mlp.hook_post,15782,0.951174970941705,0.9590254706533776,0.9976958525345622,0.9232409381663113,25.9928 11 | pythia-6.9b,distribution_id.pyth.512.-1,enron,17,mean,mlp.hook_post,15140,0.94834491091295,0.9510489510489512,0.9315068493150684,0.9714285714285714,14.508387 12 | pythia-6.9b,distribution_id.pyth.512.-1,enron,17,mean,mlp.hook_post,14623,0.9302838923103667,0.9333333333333333,0.9692307692307692,0.9,22.055277 13 | pythia-6.9b,distribution_id.pyth.512.-1,enron,22,mean,mlp.hook_post,15422,0.9194359394150236,0.9225589225589225,0.8726114649681529,0.9785714285714285,10.719464 14 | pythia-6.9b,distribution_id.pyth.512.-1,hackernews,30,mean,mlp.hook_post,5233,0.916887246381335,0.9205479452054796,0.9882352941176471,0.8615384615384616,23.487501 15 | -------------------------------------------------------------------------------- /interpretable_neurons/pythia-6.9b/top_fact_neurons.csv: -------------------------------------------------------------------------------- 1 | model,dataset,feature,layer,agg,loc,k,test_mcc,test_f1_score,test_precision,test_recall,coef,neuron 2 | pythia-6.9b,wikidata_sorted_occupation_athlete.pyth.128.5000,association football player,19,max,mlp.hook_post,1,0.8723107680581776,0.8893280632411067,0.9782608695652174,0.8152173913043478,12.316481,10761 3 | pythia-6.9b,wikidata_sorted_occupation_athlete.pyth.128.5000,baseball player,19,max,mlp.hook_post,1,0.8693890385607322,0.8888888888888888,0.9626556016597511,0.8256227758007118,12.106168,549 4 | pythia-6.9b,wikidata_sorted_occupation_athlete.pyth.128.5000,baseball player,20,max,mlp.hook_post,1,0.8508430742632211,0.8740458015267176,0.9423868312757202,0.8149466192170819,12.066226,13139 5 | 
pythia-6.9b,wikidata_sorted_occupation_athlete.pyth.128.5000,baseball player,21,max,mlp.hook_post,1,0.8272866490245298,0.8532818532818532,0.9324894514767933,0.7864768683274022,11.29498,16105 6 | pythia-6.9b,wikidata_sorted_occupation_athlete.pyth.128.5000,American football player,20,max,mlp.hook_post,1,0.8153000990286461,0.8477508650519031,0.9176029962546817,0.7877813504823151,11.799308,10306 7 | pythia-6.9b,wikidata_sorted_occupation_athlete.pyth.128.5000,ice hockey player,20,max,mlp.hook_post,1,0.7744443892864911,0.8066914498141263,0.9234042553191489,0.7161716171617162,10.730627,2367 8 | pythia-6.9b,wikidata_sorted_sex_or_gender.pyth.128.6000,female,12,max,mlp.hook_post,1,0.9498274537957756,0.9738636363636364,0.9965116279069768,0.9522222222222222,17.750957,4043 9 | pythia-6.9b,wikidata_sorted_is_alive.pyth.128.6000,true,17,max,mlp.hook_post,1,0.7848760827593265,0.8762254901960784,0.959731543624161,0.8060879368658399,14.188927,5653 10 | pythia-6.9b,wikidata_sorted_sex_or_gender.pyth.128.6000,male,9,max,mlp.hook_post,1,0.7454832218856192,0.8522378908645003,0.9507523939808481,0.7722222222222223,13.740647,996 11 | pythia-6.9b,wikidata_sorted_occupation.pyth.128.6000,athlete,9,max,mlp.hook_post,1,0.7829941416093861,0.8112149532710281,0.8930041152263375,0.7431506849315068,11.93753,12997 12 | pythia-6.9b,wikidata_sorted_is_alive.pyth.128.6000,false,14,max,mlp.hook_post,1,0.6919494733083441,0.8051118210862619,0.9662576687116564,0.6900328587075575,11.010353,205 13 | pythia-6.9b,wikidata_sorted_occupation.pyth.128.6000,actor,9,max,mlp.hook_post,1,0.7222237492560692,0.7687188019966723,0.7751677852348994,0.7623762376237624,10.875046,4502 14 | pythia-6.9b,wikidata_sorted_occupation.pyth.128.6000,singer,11,max,mlp.hook_post,1,0.7049523913398199,0.7380497131931166,0.8654708520179372,0.6433333333333333,9.797108,12667 15 | pythia-6.9b,wikidata_sorted_occupation_athlete.pyth.128.5000,basketball player,19,max,mlp.hook_post,1,0.5705885545300837,0.6546052631578947,0.7132616487455197,0.6048632218844985,8.143844,3520 16 | pythia-6.9b,wikidata_sorted_occupation.pyth.128.6000,journalist,9,max,mlp.hook_post,1,0.6054755652697862,0.653211009174312,0.7876106194690266,0.5579937304075235,8.728869,3974 17 | pythia-6.9b,wikidata_sorted_occupation.pyth.128.6000,politician,9,max,mlp.hook_post,1,0.5740938565534484,0.6202783300198806,0.7609756097560976,0.5234899328859061,6.798583,613 18 | pythia-6.9b,wikidata_sorted_political_party.pyth.128.3000,Republican Party,7,max,mlp.hook_post,1,0.22512247179808548,0.5772946859903382,0.6322751322751323,0.5311111111111111,1.8690517,3710 19 | pythia-6.9b,wikidata_sorted_political_party.pyth.128.3000,Democratic Party,18,max,mlp.hook_post,1,0.3001225239993905,0.5609756097560976,0.71875,0.46,1.4787625,15606 20 | -------------------------------------------------------------------------------- /interpretable_neurons/pythia-70m/biasxnorm_p20_50_80.csv: -------------------------------------------------------------------------------- 1 | layer,neuron,percentile,biasxnorm 2 | 1,828,20,-0.6336406533373165 3 | 1,221,50,-0.357005106965687 4 | 1,87,80,-0.16662166286988622 5 | -------------------------------------------------------------------------------- /interpretable_neurons/pythia-70m/monosemantic_language_neurons.csv: -------------------------------------------------------------------------------- 1 | model,dataset,feature,layer,aggregation,hook_loc,neuron,test_mcc,test_f1_score,test_precision,test_recall,coef 2 | 
pythia-70m,natural_lang_id.pyth.512.-1,de,3,mean,mlp.hook_post,343,1.0,1.0,1.0,1.0,5.429084 3 | pythia-70m,natural_lang_id.pyth.512.-1,fr,3,mean,mlp.hook_post,609,1.0,1.0,1.0,1.0,5.855191 4 | pythia-70m,natural_lang_id.pyth.512.-1,it,4,mean,mlp.hook_post,627,1.0,1.0,1.0,1.0,6.572351 5 | pythia-70m,natural_lang_id.pyth.512.-1,el,5,mean,mlp.hook_post,1434,1.0,1.0,1.0,1.0,6.6436276 6 | pythia-70m,natural_lang_id.pyth.512.-1,el,5,mean,mlp.hook_post,1645,1.0,1.0,1.0,1.0,8.188068 7 | pythia-70m,natural_lang_id.pyth.512.-1,pt,4,mean,mlp.hook_post,1986,0.9993983932000134,0.9994649545211343,1.0,0.9989304812834224,8.316875 8 | pythia-70m,natural_lang_id.pyth.512.-1,el,0,mean,mlp.hook_post,1006,1.0,1.0,1.0,1.0,8.806415 9 | pythia-70m,natural_lang_id.pyth.512.-1,it,5,mean,mlp.hook_post,1856,1.0,1.0,1.0,1.0,9.4072895 10 | pythia-70m,natural_lang_id.pyth.512.-1,it,4,mean,mlp.hook_post,45,1.0,1.0,1.0,1.0,9.975465 11 | pythia-70m,natural_lang_id.pyth.512.-1,nl,3,mean,mlp.hook_post,786,0.999397264142742,0.9994638069705095,1.0,0.9989281886387996,10.48039 12 | pythia-70m,natural_lang_id.pyth.512.-1,sv,4,mean,mlp.hook_post,230,1.0,1.0,1.0,1.0,10.742738 13 | pythia-70m,natural_lang_id.pyth.512.-1,de,5,mean,mlp.hook_post,894,1.0,1.0,1.0,1.0,11.527867 14 | pythia-70m,natural_lang_id.pyth.512.-1,fr,5,mean,mlp.hook_post,751,1.0,1.0,1.0,1.0,11.803505 15 | pythia-70m,natural_lang_id.pyth.512.-1,fr,5,mean,mlp.hook_post,670,1.0,1.0,1.0,1.0,12.895436 16 | pythia-70m,natural_lang_id.pyth.512.-1,sv,3,mean,mlp.hook_post,98,0.9993880511967762,0.9994544462629569,1.0,0.9989094874591058,13.709608 17 | pythia-70m,natural_lang_id.pyth.512.-1,de,4,mean,mlp.hook_post,1559,1.0,1.0,1.0,1.0,13.793328 18 | pythia-70m,natural_lang_id.pyth.512.-1,nl,4,mean,mlp.hook_post,2047,1.0,1.0,1.0,1.0,13.881838 19 | pythia-70m,natural_lang_id.pyth.512.-1,sv,4,mean,mlp.hook_post,21,1.0,1.0,1.0,1.0,14.509701 20 | pythia-70m,natural_lang_id.pyth.512.-1,en,5,mean,mlp.hook_post,122,0.999401751056315,0.9994683678894205,1.0,0.9989373007438895,14.87958 21 | pythia-70m,natural_lang_id.pyth.512.-1,en,5,mean,mlp.hook_post,765,1.0,1.0,1.0,1.0,15.815285 22 | pythia-70m,natural_lang_id.pyth.512.-1,en,5,mean,mlp.hook_post,285,0.999401751056315,0.9994683678894205,1.0,0.9989373007438895,16.341928 23 | pythia-70m,natural_lang_id.pyth.512.-1,es,5,mean,mlp.hook_post,632,0.9994061609462219,0.9994728518713759,1.0,0.9989462592202318,17.05969 24 | pythia-70m,natural_lang_id.pyth.512.-1,nl,5,mean,mlp.hook_post,987,0.999397264142742,0.9994638069705095,1.0,0.9989281886387996,18.002659 25 | pythia-70m,natural_lang_id.pyth.512.-1,es,4,mean,mlp.hook_post,292,0.9994061609462219,0.9994728518713759,1.0,0.9989462592202318,18.887386 26 | -------------------------------------------------------------------------------- /interpretable_neurons/pythia-70m/pythia70m_prime_factor_neurons.csv: -------------------------------------------------------------------------------- 1 | layer,neuron,test_mcc 2 | 1,1117,0.998657327028896 3 | 1,1079,0.9986550999226772 4 | 1,111,0.9973186705998264 5 | 1,749,0.9959840091200889 6 | 1,828,0.9946533211589706 7 | 1,1926,0.9946533211589706 8 | 1,1761,0.994633736771317 9 | 1,709,0.9933265854461832 10 | 1,1644,0.9933012841286317 11 | 2,2026,0.9933012841286317 12 | 2,1535,0.9933012841286317 13 | 1,1494,0.9933012841286317 14 | 1,1748,0.9920037808703182 15 | 2,1136,0.9920037808703182 16 | 1,702,0.9920037808703182 17 | 2,1047,0.9919727965780535 18 | 1,953,0.9906848864773178 19 | 5,27,0.9906482528935814 20 | 2,575,0.9893698814689665 21 | 1,367,0.9880587452013986 
22 | 3,1843,0.9880109130101395 23 | 1,1250,0.9854479970760662 24 | 2,238,0.9840839605658703 25 | 1,827,0.9828524799817793 26 | 2,1142,0.9827826431046657 27 | 2,652,0.9827826431046657 28 | 2,1402,0.9814851252826131 29 | 1,1956,0.9802720341759834 30 | 4,1458,0.9801913871003504 31 | 1,36,0.9789874137256388 32 | 2,1597,0.9789014087053487 33 | 1,630,0.9777065022483363 34 | 1,670,0.9777065022483363 35 | 2,1698,0.9776151703905261 36 | 2,776,0.97515572907696 37 | 1,1852,0.97515572907696 38 | 2,585,0.97515572907696 39 | 3,1639,0.9750538358921039 40 | 2,1944,0.9748692013374924 41 | 2,1232,0.9738858291644985 42 | 2,105,0.9737787010093012 43 | 2,103,0.972619561785693 44 | 1,1382,0.9725072288056059 45 | 1,1799,0.9725072288056059 46 | 2,925,0.9700978497023146 47 | 1,87,0.9700978497023146 48 | 4,1106,0.9698587285342519 49 | 1,124,0.9675904443177165 50 | 2,967,0.9675904443177165 51 | 2,666,0.964954257388028 52 | 2,653,0.9626179701854413 53 | 3,1510,0.9609277075181428 54 | 1,655,0.9589250959815859 55 | 2,923,0.9589250959815859 56 | 1,1724,0.9587574340034344 57 | 3,1221,0.9585955769175792 58 | 3,734,0.9580724421928709 59 | 5,616,0.9580072558629034 60 | 2,1984,0.9575284703525223 61 | 2,1051,0.9568965453585395 62 | 1,1902,0.956480288964665 63 | 1,218,0.9552629665216523 64 | 2,1962,0.9550808038687473 65 | 2,205,0.9526467036896862 66 | 2,931,0.9502260337600053 67 | 1,1918,0.94802910433467 68 | 2,1413,0.9454244498822085 69 | 3,1883,0.94440774771312 70 | 4,928,0.9423889384736618 71 | 3,56,0.9402151409772855 72 | 5,1615,0.9386055613287819 73 | 4,1811,0.9361936312434651 74 | 3,1409,0.9359766652116369 75 | 3,1531,0.934809950460351 76 | 4,536,0.9319636575143146 77 | 4,2015,0.9307975353462645 78 | 4,584,0.9287913974746623 79 | 4,1498,0.9283808959648853 80 | 3,1865,0.9267293803488141 81 | 4,1875,0.9262638642981481 82 | 2,1440,0.9255268511985371 83 | 2,450,0.9187964622218202 84 | 5,1446,0.9152947798541844 85 | 3,696,0.911299719742114 86 | 5,2014,0.9092169453086857 87 | 3,108,0.9062593433854699 88 | 3,2006,0.904901886019403 89 | 3,1500,0.904799130941486 90 | 3,724,0.9037125383643637 91 | 3,346,0.9030030608769091 92 | 4,1461,0.9029244469174559 93 | 4,1643,0.9003214470503625 94 | 4,1551,0.8958360489928912 95 | 3,1857,0.8953887183702398 96 | 5,452,0.8936822573999782 97 | 3,675,0.8911884445504328 98 | 3,130,0.8911884445504328 99 | 4,1849,0.8894453223199775 100 | 5,1746,0.8878506210691754 101 | 4,1460,0.8859827962477783 102 | -------------------------------------------------------------------------------- /interpretable_neurons/pythia-70m/wikidata.csv: -------------------------------------------------------------------------------- 1 | model,dataset,feature,layer,neuron,test_mcc,test_f1_score,coef 2 | pythia-70m,wikidata_sex_or_gender.pyth.128.10000,female,3,897,0.602,0.778,18.215 3 | pythia-70m,wikidata_sex_or_gender.pyth.128.10000,female,5,1333,0.528,0.719,15.590 4 | pythia-70m,wikidata_sex_or_gender.pyth.128.10000,female,1,1368,0.168,0.599,-3.193 5 | pythia-70m,wikidata_sex_or_gender.pyth.128.10000,female,4,1456,0.228,0.597,5.333 6 | pythia-70m,wikidata_sex_or_gender.pyth.128.10000,female,2,718,0.204,0.559,4.323 7 | pythia-70m,wikidata_sex_or_gender.pyth.128.10000,female,0,715,0.132,0.514,2.179 8 | pythia-70m,wikidata_sex_or_gender.pyth.128.10000,male,3,897,0.602,0.815,-18.217 9 | pythia-70m,wikidata_sex_or_gender.pyth.128.10000,male,5,1333,0.528,0.786,-15.550 10 | pythia-70m,wikidata_sex_or_gender.pyth.128.10000,male,2,718,0.204,0.637,-4.323 11 | 
pythia-70m,wikidata_sex_or_gender.pyth.128.10000,male,4,1456,0.228,0.630,-5.333 12 | pythia-70m,wikidata_sex_or_gender.pyth.128.10000,male,0,715,0.132,0.608,-2.175 13 | pythia-70m,wikidata_sex_or_gender.pyth.128.10000,male,1,1368,0.168,0.565,3.190 14 | pythia-70m,wikidata_is_alive.pyth.128.10000,false,3,925,0.297,0.677,-6.519 15 | pythia-70m,wikidata_is_alive.pyth.128.10000,false,4,688,0.278,0.645,-6.495 16 | pythia-70m,wikidata_is_alive.pyth.128.10000,false,1,472,0.212,0.641,-4.303 17 | pythia-70m,wikidata_is_alive.pyth.128.10000,false,0,320,0.106,0.604,-1.096 18 | pythia-70m,wikidata_is_alive.pyth.128.10000,false,2,1706,0.128,0.596,-2.917 19 | pythia-70m,wikidata_is_alive.pyth.128.10000,false,5,1791,0.264,0.537,6.591 20 | pythia-70m,wikidata_is_alive.pyth.128.10000,true,5,1791,0.264,0.688,-6.590 21 | pythia-70m,wikidata_is_alive.pyth.128.10000,true,4,688,0.278,0.630,6.500 22 | pythia-70m,wikidata_is_alive.pyth.128.10000,true,3,925,0.297,0.594,6.518 23 | pythia-70m,wikidata_is_alive.pyth.128.10000,true,1,472,0.212,0.547,4.304 24 | pythia-70m,wikidata_is_alive.pyth.128.10000,true,2,1706,0.128,0.516,2.918 25 | pythia-70m,wikidata_is_alive.pyth.128.10000,true,0,320,0.106,0.470,1.096 26 | pythia-70m,wikidata_occupation.pyth.128.30000,actor,5,885,0.160,0.266,4.030 27 | pythia-70m,wikidata_occupation.pyth.128.30000,association football player,5,637,0.224,0.320,5.264 28 | pythia-70m,wikidata_occupation.pyth.128.30000,basketball player,5,637,0.219,0.310,5.736 29 | pythia-70m,wikidata_occupation.pyth.128.30000,businessperson,5,1007,0.080,0.146,2.434 30 | pythia-70m,wikidata_occupation.pyth.128.30000,journalist,5,1613,0.122,0.240,4.173 31 | pythia-70m,wikidata_occupation.pyth.128.30000,lawyer,5,1078,0.174,0.270,4.721 32 | pythia-70m,wikidata_occupation.pyth.128.30000,politician,5,1233,0.209,0.305,5.714 33 | pythia-70m,wikidata_occupation.pyth.128.30000,researcher,4,303,0.187,0.289,-5.048 34 | pythia-70m,wikidata_occupation.pyth.128.30000,singer,5,873,0.263,0.351,5.654 35 | pythia-70m,wikidata_occupation.pyth.128.30000,writer,4,171,0.162,0.271,4.383 -------------------------------------------------------------------------------- /load.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import pickle 4 | import datasets 5 | from transformer_lens import HookedTransformer, utils 6 | from config import FeatureDatasetConfig 7 | 8 | 9 | def load_model(model_name="gpt2-small", device=None): 10 | if device is None: 11 | device = "cuda" if torch.cuda.is_available() else "cpu" 12 | if model_name.startswith('pythia'): 13 | try: 14 | model = HookedTransformer.from_pretrained( 15 | model_name + '-70m', device='cpu') 16 | except ValueError: 17 | print(f'No {model_name}-v0 available') 18 | model = HookedTransformer.from_pretrained(model_name, device='cpu') 19 | model.eval() 20 | torch.set_grad_enabled(False) 21 | if model.cfg.device != device: 22 | try: 23 | model.to(device) 24 | except RuntimeError: 25 | print( 26 | f"WARNING: model is too large to fit on {device}. 
Falling back to CPU") 27 | model.to('cpu') 28 | 29 | return model 30 | 31 | 32 | def load_feature_dataset(name, n=-1): 33 | path = os.path.join(os.environ.get( 34 | 'FEATURE_DATASET_DIR', 'feature_datasets'), name) 35 | if n > 0: 36 | return datasets.load_from_disk(path).select(range(n)) 37 | else: 38 | return datasets.load_from_disk(path) 39 | 40 | 41 | def load_raw_dataset(path, n_seqs=-1): 42 | save_path = os.path.join(os.environ['HF_DATASETS_CACHE'], path) 43 | dataset = datasets.load_from_disk(save_path) 44 | if n_seqs > 0: 45 | dataset = dataset.select(range(n_seqs)) 46 | return dataset 47 | -------------------------------------------------------------------------------- /probing_datasets/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wesg52/sparse-probing-paper/a610e102c6e25a6ef9cc16c3a2abb736aa90849b/probing_datasets/__init__.py -------------------------------------------------------------------------------- /probing_datasets/common.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import datasets 4 | from scipy import sparse 5 | 6 | 7 | class FeatureDataset: 8 | 9 | def __init__(self, name): 10 | self.name = name 11 | 12 | def load(self, dataset_config): 13 | file_loc = os.path.join( 14 | os.getenv('FEATURE_DATASET_DIR', 'feature_datasets'), 15 | dataset_config.make_dir_name() 16 | ) 17 | return datasets.load_from_disk(file_loc) 18 | 19 | def save(self, dataset_config, feature_dataset): 20 | file_loc = os.path.join( 21 | os.getenv('FEATURE_DATASET_DIR', 'feature_datasets'), 22 | dataset_config.make_dir_name() 23 | ) 24 | os.makedirs(file_loc, exist_ok=True) 25 | feature_dataset.save_to_disk(file_loc) 26 | 27 | 28 | def get_char_to_tok_map(decoded_seq, sequence_indices): 29 | sequence_indices = np.concatenate([np.array([0]), sequence_indices]) 30 | char_to_tok = np.zeros(len(decoded_seq), dtype=int) 31 | for i, (start, end) in enumerate(zip(sequence_indices[:-1], sequence_indices[1:])): 32 | char_to_tok[start:end] = i 33 | return char_to_tok 34 | 35 | 36 | def tokenize_and_concatenate_separate_subsequences( 37 | dataset, 38 | tokenizer, 39 | streaming=False, 40 | max_length=1024, 41 | column_name="text", 42 | add_bos_token=True, 43 | random_subsequence=True, 44 | num_proc=10, 45 | ): 46 | """Helper function to process text 47 | Args: 48 | dataset (Dataset): The dataset to tokenize, assumed to be a HuggingFace text dataset. 49 | tokenizer (AutoTokenizer): The tokenizer. Assumed to have a bos_token_id and an eos_token_id. 50 | streaming (bool, optional): Whether the dataset is being streamed. If True, avoids using parallelism. Defaults to False. 51 | max_length (int, optional): The length of the context window of the sequence. Defaults to 1024. 52 | column_name (str, optional): The name of the text column in the dataset. Defaults to 'text'. 53 | add_bos_token (bool, optional): . Defaults to True. 
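        random_subsequence (bool, optional): If True, take a random max_length window from each
            tokenized sequence rather than always keeping the prefix. Defaults to True.
        num_proc (int, optional): Number of processes used for the tokenization map when not
            streaming. Defaults to 10.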
54 | Returns: 55 | Dataset: Returns the tokenized dataset, as a dataset of tensors, with a single column called "tokens" 56 | """ 57 | def subsample_sequence(tokens, max_length, pad_id, add_bos_token=True, random_subsequence=True): 58 | """Subsample a sequence to a maximum length, padding if necessary.""" 59 | if add_bos_token: 60 | seq_len = max_length - 1 61 | else: 62 | seq_len = max_length 63 | 64 | tok_seqs = [] 65 | for tok_seq in tokens: 66 | if len(tok_seq) > seq_len: 67 | # Take random seq_len length subsequence 68 | start = np.random.randint( 69 | 0, len(tok_seq) - seq_len) if random_subsequence else 0 70 | tok_seq = tok_seq[start: start + seq_len] 71 | else: 72 | # Pad to seq_len 73 | tok_seq = np.pad( 74 | tok_seq, (0, seq_len - len(tok_seq)), constant_values=pad_id) 75 | tok_seqs.append(tok_seq) 76 | 77 | tok_arr = np.array(tok_seqs) 78 | if add_bos_token: 79 | bos_arr = np.ones((len(tok_seqs), 1), 80 | dtype=np.int32) * tokenizer.bos_token_id 81 | tok_arr = np.hstack([bos_arr, tok_arr]) 82 | 83 | return tok_arr.astype(np.int32).tolist() 84 | 85 | if tokenizer.pad_token is None: 86 | print('WARNING: model does not have a pad token') 87 | tokenizer.add_special_tokens({"pad_token": ""}) 88 | # Define the length to chop things up into - leaving space for a bos_token if required 89 | 90 | tokenized_dataset = dataset.map( 91 | lambda x: {'all_tokens': tokenizer(x[column_name])['input_ids']}, 92 | batched=True, 93 | num_proc=(num_proc if not streaming else None), 94 | ) 95 | print('Finished tokenizing dataset, beginning to subsample sequences') 96 | subsampled_sequences = subsample_sequence( 97 | tokenized_dataset['all_tokens'], max_length, tokenizer.pad_token_id, add_bos_token, random_subsequence) 98 | tokenized_dataset = tokenized_dataset.add_column( 99 | 'tokens', subsampled_sequences) 100 | 101 | tokenized_dataset.set_format(type="torch") 102 | 103 | return tokenized_dataset 104 | -------------------------------------------------------------------------------- /probing_datasets/counterfact.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | import torch 4 | import numpy as np 5 | from datasets.arrow_dataset import Dataset 6 | from transformers import AutoTokenizer 7 | 8 | from probing_datasets.common import FeatureDataset 9 | from config import FeatureDatasetConfig 10 | 11 | 12 | class CounterfactFeatureDataset(FeatureDataset): 13 | 14 | def __init__(self): 15 | pass 16 | 17 | def prepare_dataset(self, exp_cfg): 18 | ''' 19 | Return valid indices and classes for the feature dataset. 
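        Returns (dataset, feature_datasets), where feature_datasets maps the single
        feature 'text_true' to (feature_indices, feature_classes): indices point at the
        last target token within the flattened (n_seq * seq_len,) token array, and
        classes are +1 for true completions and -1 for false ones.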
20 | ''' 21 | dataset = self.load(exp_cfg.dataset_cfg) 22 | 23 | # index is position within flattened (n_seq x seq_len,) array 24 | feature_indices = dataset['target_end'] - 1 25 | feature_indices += torch.arange(len(feature_indices)) * len(dataset[0]['tokens']) 26 | 27 | # classes to {-1, +1} 28 | feature_classes = np.full(len(dataset), -1) 29 | feature_classes[dataset['text_true']] = 1 30 | 31 | feature_datasets = {'text_true': (feature_indices, feature_classes)} 32 | return dataset, feature_datasets 33 | 34 | def make( 35 | self, 36 | dataset_config: FeatureDatasetConfig, 37 | args: dict, # command line arguments from make_feature_datasets 38 | raw_dataset: Dataset, 39 | tokenizer: AutoTokenizer, 40 | cache=True, 41 | ) -> Dataset: 42 | ''' 43 | Returns feature_dataset with columns: 44 | text: raw strings 45 | tokens: tokenized strings of consistent length 46 | text_true: boolean class label 47 | subject_start: index of the start of the subject 48 | subject_end: index of the end of the subject 49 | target_start: index of the start of the target 50 | target_end: index of the end of the target 51 | relation_id: relation id 52 | ''' 53 | 54 | # create a positive and negative example for each in the batch 55 | def create_pos_neg(batch): 56 | text = [(p + tt, p + tf) for p, tt, tf in zip(batch['prompt'], batch['target_true'], batch['target_false'])] 57 | text_flat = [t for pair in text for t in pair] 58 | return { 59 | 'text': text_flat, 60 | 'text_true': len(batch['prompt']) * [True, False], 61 | 'relation_prefix': [r for r in batch['relation_prefix'] for _ in range(2)], 62 | 'subject': [s for s in batch['subject'] for _ in range(2)], 63 | 'prompt': [p for p in batch['prompt'] for _ in range(2)], 64 | 'relation_id': [rid for rid in batch['relation_id'] for _ in range(2)], 65 | } 66 | remove_columns = ['relation', 'relation_suffix', 'target_true_id', 'target_false_id', 67 | 'target_true', 'target_false', 'subject'] 68 | feature_dataset = raw_dataset.map(create_pos_neg, batched=True, remove_columns=remove_columns) 69 | 70 | # tokenize each example, storing the indices for the subject and target 71 | def tokenize(example): 72 | # get tokens 73 | seq_len = dataset_config.ctx_len - 1 if args['add_bos'] else dataset_config.ctx_len 74 | all_tokens = tokenizer(example['text'], max_length=seq_len, truncation=True, padding='max_length')['input_ids'] 75 | tokens = [tokenizer.bos_token_id] + all_tokens if args['add_bos'] else all_tokens 76 | 77 | # get the index of the subject and target 78 | # NOTE: this isn't the most efficient but it doesn't need to be and should be pretty robust 79 | token_strs = tokenizer.batch_decode(tokens, clean_up_tokenization_spaces=True) 80 | sentence = "" 81 | index_counter = 0 82 | search_str = tokenizer.bos_token + example['relation_prefix'] if args['add_bos'] else example['relation_prefix'] 83 | for i, token_str in enumerate(token_strs + ['']): 84 | if index_counter == 0: # start of subject 85 | if search_str in sentence: 86 | example['subject_start'] = i 87 | search_str += example['subject'] 88 | index_counter += 1 89 | if index_counter == 1: # (1 after) end of subject 90 | if search_str in sentence: 91 | example['subject_end'] = i 92 | search_str = example['prompt'] 93 | index_counter += 1 94 | if index_counter == 2: # start of target 95 | if search_str in sentence: 96 | example['target_start'] = i 97 | search_str = example['text'] 98 | index_counter += 1 99 | if index_counter == 3: # (1 after) end of target 100 | if search_str in sentence: 101 | example['target_end'] 
= i 102 | break 103 | sentence += token_str 104 | 105 | # print(example['text'], token_strs[probe_indices[0]:probe_indices[1]], token_strs[probe_indices[2]:probe_indices[3]], probe_indices) 106 | 107 | example['tokens'] = tokens 108 | # valid_indicies_min = 1 if args['add_bos'] else 0 109 | # valid_indicies_max = min(len(tokenizer(example['text'])['input_ids']) + valid_indicies_min, dataset_config.ctx_len) 110 | # example['valid_indices'] = list(range(valid_indicies_min, valid_indicies_max)) 111 | 112 | return example 113 | 114 | feature_dataset = feature_dataset.map(tokenize) 115 | 116 | # this is a bit messy but sometimes we have a problem with the tokenization decoding (e.g. uses weird characters) and we just skip the example 117 | probe_indices_keys = ['subject_start', 'subject_end', 'target_start', 'target_end'] 118 | feature_dataset = feature_dataset.filter(lambda example: set(probe_indices_keys).issubset(set(example.keys()))) 119 | 120 | # clean up dataset columns and datatypes 121 | if dataset_config.n_sequences > 0: 122 | feature_dataset = feature_dataset.select(range(dataset_config.n_sequences)) 123 | 124 | feature_dataset = feature_dataset.remove_columns(['relation_prefix', 'prompt', 'subject']) 125 | feature_dataset.set_format(type="torch", columns=['tokens'] + probe_indices_keys, 126 | output_all_columns=True) 127 | 128 | if cache: 129 | self.save(dataset_config, feature_dataset) 130 | 131 | return feature_dataset 132 | -------------------------------------------------------------------------------- /probing_datasets/distribution_id.py: -------------------------------------------------------------------------------- 1 | import os 2 | import datasets 3 | import torch 4 | import numpy as np 5 | from .common import * 6 | 7 | DATASET_SPLITS = { 8 | 'Wikipedia (en)': 'wikipedia', 9 | 'PubMed Abstracts': 'pubmed_abstracts', 10 | 'StackExchange': 'stack_exchange', 11 | 'Github': 'github', 12 | 'ArXiv': 'arxiv', 13 | 'USPTO Backgrounds': 'uspto', 14 | 'FreeLaw': 'freelaw', 15 | 'HackerNews': 'hackernews', 16 | 'Enron Emails': 'enron' 17 | } 18 | 19 | 20 | class DataDistributionIDFeatureDataset(FeatureDataset): 21 | 22 | def __init__(self, name, data_splits): 23 | self.name = name 24 | self.data_splits = data_splits 25 | 26 | def prepare_dataset(self, exp_cfg): 27 | """Convert categorial labels into binary labels for each data_splits. 28 | 29 | Returns: tokenized_dataset, feature_datasets with structure: 30 | {data_splits: (indices, classes)}. 31 | 32 | ...except when exp_cfg.aggregation is not None. 
Then 33 | indices are valid_index mask with class per row.""" 34 | dataset = self.load(exp_cfg.dataset_cfg) 35 | _, n = dataset['probe_indices'].shape 36 | 37 | feature_datasets = {k: {'indices': [], 'classes': []} 38 | for k in self.data_splits.values()} 39 | 40 | dataset_distribution = np.array(dataset['distribution']) 41 | 42 | if exp_cfg.activation_aggregation is None: 43 | valid_index_mask = torch.zeros_like(dataset['tokens']) 44 | for ix, valid_seq_indices in enumerate(dataset['probe_indices']): 45 | valid_index_mask[ix, valid_seq_indices] = 1 46 | valid_indices = np.where(valid_index_mask.flatten())[0] 47 | 48 | extended_dataset_distribution = dataset_distribution[:, None].repeat( 49 | n, axis=1).flatten() 50 | for ix, split in enumerate(self.data_splits.values()): 51 | split_label = (extended_dataset_distribution == 52 | split).astype(int) * 2 - 1 53 | # same index_mask for all splits 54 | feature_datasets[split] = (valid_indices, split_label) 55 | return dataset, feature_datasets 56 | 57 | else: # each feature dataset is (size(n x ctx_len) index_mask, size(n) class) 58 | valid_index_mask = torch.zeros_like(dataset['tokens']) 59 | for ix, valid_seq_indices in enumerate(dataset['valid_indices']): 60 | valid_index_mask[ix, valid_seq_indices] = 1 61 | 62 | for ix, split in enumerate(self.data_splits.values()): 63 | split_label = (dataset_distribution == 64 | split).astype(int) * 2 - 1 65 | # same index_mask for all languages 66 | feature_datasets[split] = (valid_index_mask, split_label) 67 | return dataset, feature_datasets 68 | 69 | def make(self, dataset_config, args, raw_dataset, tokenizer, cache=True): 70 | 71 | n_probe_tokens = args.get('lang_id_n_tokens', 2) 72 | ignore_k = args.get('ignore_first_k', 25) 73 | 74 | tokenized_dataset = tokenize_and_concatenate_separate_subsequences( 75 | raw_dataset, 76 | tokenizer, 77 | max_length=dataset_config.ctx_len, 78 | add_bos_token=args.get('add_bos_token', True), 79 | ) 80 | distribution_ids = [self.data_splits[meta['pile_set_name']] 81 | for meta in raw_dataset['meta']] 82 | feature_dataset = tokenized_dataset.add_column( 83 | 'distribution', distribution_ids) 84 | # filter out too short sequences. 85 | feature_dataset = feature_dataset.filter( 86 | lambda x: len(x['all_tokens']) > ignore_k + n_probe_tokens + 1) 87 | 88 | all_probe_indices = [] 89 | all_valid_indices = [] 90 | 91 | END_TOKENS = torch.tensor( 92 | [tokenizer.eos_token_id, tokenizer.pad_token_id], dtype=torch.long) 93 | for i in range(len(feature_dataset)): 94 | # determine last valid token. 95 | if feature_dataset[i]['tokens'][-1] not in END_TOKENS: 96 | eos = len(feature_dataset[i]['tokens']) - 1 97 | else: 98 | eos = torch.isin( 99 | feature_dataset[i]['tokens'], 100 | END_TOKENS 101 | ).nonzero()[0].item() 102 | 103 | valid_indices = np.arange(ignore_k, eos) 104 | if len(valid_indices) > n_probe_tokens: 105 | probe_indices = sorted(np.random.choice( 106 | valid_indices, size=n_probe_tokens, replace=False).tolist()) 107 | else: 108 | probe_indices = valid_indices.tolist() 109 | 110 | all_probe_indices.append(probe_indices) 111 | all_valid_indices.append(valid_indices) 112 | 113 | # valid indices are required for aggregation. 
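        # probe_indices are a small random sample of usable token positions (used for
        # per-token probing), while valid_indices cover every non-padding position after
        # the first ignore_k tokens (used when activations are aggregated over the sequence).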
114 | feature_dataset = feature_dataset.add_column( 115 | 'probe_indices', all_probe_indices) 116 | feature_dataset = feature_dataset.add_column( 117 | 'valid_indices', all_valid_indices) 118 | 119 | feature_dataset = feature_dataset.filter( 120 | lambda x: len(x['valid_indices']) > 0) 121 | 122 | feature_dataset.set_format('torch') 123 | 124 | if cache: 125 | self.save(dataset_config, feature_dataset) 126 | 127 | return feature_dataset 128 | 129 | 130 | DISPLAY_NAMES = { 131 | 'wikipedia': 'Wikipedia', 132 | 'pubmed_abstracts': 'PubMed', 133 | 'stack_exchange': 'StackEx', 134 | 'github': 'Github', 135 | 'arxiv': 'ArXiv', 136 | 'uspto': 'USPTO', 137 | 'freelaw': 'FreeLaw', 138 | 'hackernews': 'HackNews', 139 | 'enron': 'Enron' 140 | } 141 | -------------------------------------------------------------------------------- /probing_datasets/multitoken_supervised.py: -------------------------------------------------------------------------------- 1 | import re 2 | import numpy as np 3 | from .common import * 4 | 5 | 6 | import regex 7 | 8 | # https://stackoverflow.com/questions/26385984/recursive-pattern-in-regex 9 | IN_QUOTES_REGEX = re.compile(r'[“"”](.*?)[”"]') 10 | 11 | IN_PARENTHESES_REGEX = regex.compile(r"\(((?>[^()]+|(?R))*)\)") 12 | 13 | IN_BRACKETS_REGEX = regex.compile(r"\[((?>[^\[\]]+|(?R))*)\]") 14 | 15 | IN_ANGLE_BRACKETS_REGEX = regex.compile(r"<((?>[^<>]+|(?R))*)>") 16 | 17 | IN_CURLY_BRACKETS_REGEX = regex.compile(r"{((?>[^{}]+|(?R))*)}") 18 | 19 | CONTAINMENT_REGEXES = [ 20 | ('in_quotes', IN_QUOTES_REGEX), 21 | ('in_parentheses', IN_PARENTHESES_REGEX), 22 | ('in_brackets', IN_BRACKETS_REGEX), 23 | ('in_angle_brackets', IN_ANGLE_BRACKETS_REGEX), 24 | ('in_curly_brackets', IN_CURLY_BRACKETS_REGEX) 25 | ] 26 | 27 | 28 | class ContainmentClassificationFeatureDataset(FeatureDataset): 29 | def __init__(self, name, containment_regexes): 30 | self.name = name 31 | self.containment_regexes = containment_regexes 32 | 33 | def make(self, dataset_config, args, raw_dataset, tokenizer, cache=True): 34 | pass 35 | 36 | def prepare_dataset(self, exp_cfg): 37 | pass 38 | 39 | 40 | # for containment regexes 41 | def containment_mapping(token_list, tokenizer, containment_regexes): 42 | vocab = tokenizer.get_vocab() 43 | ix2str = {v: tokenizer.decode(v) for v in vocab.values()} 44 | token_length = {k: len(v) for k, v in ix2str.items()} 45 | 46 | get_tok_len = np.vectorize(lambda x: token_length[x]) 47 | get_tok_str = np.vectorize(lambda x: ix2str[x]) 48 | 49 | regex_matches = {name: [] for name, _ in containment_regexes} 50 | for i in range(len(token_list)): 51 | token_seq = token_list[i] 52 | decoded_seq = ''.join(get_tok_str(token_seq)) 53 | tok_lens = get_tok_len(token_seq) 54 | sequence_indices = np.cumsum(tok_lens) 55 | 56 | char2tok = get_char_to_tok_map(decoded_seq, sequence_indices) 57 | 58 | for name, reg in containment_regexes: 59 | feature_char_spans = [m.span() for m in reg.finditer(decoded_seq)] 60 | valid_spans = [ 61 | span for span in feature_char_spans if span[1] - span[0] > 2] 62 | token_spans = [(char2tok[start]+1, char2tok[end-1]-1) 63 | for start, end in valid_spans] 64 | regex_matches[name].append(token_spans) 65 | return regex_matches 66 | -------------------------------------------------------------------------------- /probing_datasets/neuron_stimulus.py: -------------------------------------------------------------------------------- 1 | from .common import * 2 | from transformer_lens import utils 3 | import numpy as np 4 | import torch 5 | from datasets import Dataset 
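
# Hypothetical helper (a sketch, not referenced by the rest of this module): render the
# stimulus table defined below as human-readable "prefix|target" strings, useful for
# sanity-checking the hard-coded token ids against the tokenizer (e.g. the entry
# 12299: [[(13804,), (9432,)]] is annotated as {Har, har}|vard).
def preview_stimuli(stimuli, tokenizer):
    previews = []
    for target_token, prefix_classes in stimuli.items():
        # make() below only uses the first (strongest-activating) prefix class
        for prefix in prefix_classes[0]:
            previews.append(
                tokenizer.decode(list(prefix)) + '|' + tokenizer.decode(target_token))
    return previews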
6 | 7 | # token_ix: [list of prefixes to create classes for] 8 | # each list containts a set of prefixes 9 | # where a prefix is a tuple of tokens 10 | PYTHIA_70M_L1_N111_STIMULI = { 11 | 12299: [[(13804,), (9432,)]], # e.g. {Har, har}|vard 12 | 35476: [[(43950, 762, 253), (3567, 762, 253)], [(2256, 281, 253)]], # Boost 13 | 20740: [[(2058, 253, 5403), (3404, 253, 5403)]], # census 14 | 26268: [[(15, 12332, 9824)], [(32170, 9824)]], # Chain 15 | 412: [[(47678, ), (14029, )]], # op 16 | 20000: [[(38476, 5625)]], # Peace 17 | 3621: [[(47694, )]], # lease 18 | 13606: [[(5625, 330)]], # oven 19 | 14894: [[(12602, 273, 15123)]], # hma 20 | 2616: [[(4335, )], [(29602, )]], # factors 21 | 7736: [[(21698, )]], # pogenic 22 | 39098: [[(4146, 12761), (14594, 12761)]], # systems 23 | 4412: [[(11586, 2077)]], # District 24 | 15353: [[(36642, )]], # gate 25 | 11845: [[(19256, 38056)]], # iallance 26 | 4694: [[(47678, ), (14029, )]], # vision 27 | 20310: [[(5625, 10518), (22817, 10518)]], # Mach 28 | 6875: [[(7671, ), (31351, )]], # Science 29 | 17629: [[(6399, )]], # ograms 30 | 19934: [[(35654, 15)], [(2700, 15)]], # apple 31 | 48862: [[(45590, 64)]], # AUX 32 | 7404: [[(749, ), (2377, )]], # process 33 | 35437: [[(22817, 13940), (5625, 13940)]], # communication 34 | 7662: [[(2359, 412), (22468, 412)]], # roduction 35 | 25837: [[(681, 16)]], # fw 36 | 16240: [[(21034, )], [(20709, )]], # ington 37 | } 38 | 39 | 40 | class NeuronStimulusFeatureDataset(FeatureDataset): 41 | def __init__(self, name, stimuli): 42 | self.name = name 43 | self.stimuli = stimuli 44 | 45 | def make(self, dataset_config, args, raw_dataset, tokenizer, cache=True): 46 | if 'tokens' in raw_dataset.column_names: # already tokenized 47 | token_vector = raw_dataset['tokens'].flatten().numpy() 48 | else: 49 | tokenized_ds = utils.tokenize_and_concatenate( 50 | raw_dataset, tokenizer) 51 | token_vector = tokenized_ds['tokens'].flatten().numpy() 52 | 53 | stimulus_datasets = [] 54 | for probe_token, stimulus_class in self.stimuli.items(): 55 | # just take first class for simplicity (should be the strongest activating) 56 | stimulus = stimulus_class[0] 57 | 58 | probe_token_indices = np.where(token_vector == probe_token)[0] 59 | 60 | valid_stimulus_indices = [] 61 | for ngram_prefix in stimulus: 62 | probe_indices_with_stimulus_prefix = probe_token_indices 63 | for ix, t in enumerate(ngram_prefix[::-1]): 64 | offset = ix + 1 65 | probe_indices_with_correct_stimuli_prefix = np.where( 66 | token_vector[probe_indices_with_stimulus_prefix - offset] == t)[0] 67 | probe_indices_with_stimulus_prefix = probe_indices_with_stimulus_prefix[ 68 | probe_indices_with_correct_stimuli_prefix] 69 | valid_stimulus_indices.append( 70 | probe_indices_with_stimulus_prefix) 71 | valid_stimulus_indices = np.concatenate(valid_stimulus_indices) 72 | valid_negative_stimulus_indices = np.setdiff1d( 73 | probe_token_indices, valid_stimulus_indices) 74 | 75 | target_n_positive = min(500, len(valid_stimulus_indices)) 76 | target_n_negative = min(2000, len(valid_negative_stimulus_indices)) 77 | ctx_len = 32 78 | 79 | positive_indices = np.sort(np.random.choice( 80 | valid_stimulus_indices, target_n_positive, replace=False)) 81 | negative_indices = np.sort(np.random.choice( 82 | valid_negative_stimulus_indices, target_n_negative, replace=False)) 83 | len(positive_indices), len(negative_indices) 84 | 85 | positive_stimulus_token_tensor = np.vstack([ 86 | token_vector[ix+1-ctx_len: ix+1] for ix in positive_indices 87 | ]) 88 | negative_stimulus_token_tensor = 
np.vstack([ 89 | token_vector[ix+1-ctx_len: ix+1] for ix in negative_indices 90 | ]) 91 | stimulus_token_tensor = np.vstack([ 92 | positive_stimulus_token_tensor, negative_stimulus_token_tensor 93 | ]) 94 | 95 | token_name = tokenizer.decode(probe_token) 96 | feature_prefix = tokenizer.decode(list(stimulus[0])) 97 | feature_name = f'{feature_prefix}|{token_name}|' 98 | labels = ['positive' for _ in range(len(positive_indices))] \ 99 | + ['negative' for _ in range(len(negative_indices))] 100 | 101 | stimulus_ds = datasets.Dataset.from_dict({ 102 | 'tokens': stimulus_token_tensor, 103 | 'label': labels, 104 | 'feature_name': [feature_name for _ in range(len(labels))] 105 | }).shuffle() 106 | stimulus_datasets.append(stimulus_ds) 107 | print( 108 | f'Finished token {probe_token} {feature_name} with {len(positive_indices)} positive and {len(negative_indices)} negative stimuli') 109 | 110 | neuron_stimulus_dataset = datasets.concatenate_datasets( 111 | stimulus_datasets) 112 | neuron_stimulus_dataset.set_format(type="torch") 113 | 114 | target_n_positive = min(500, len(valid_stimulus_indices)) 115 | target_n_negative = min(2000, len(valid_negative_stimulus_indices)) 116 | ctx_len = args.get('ctx_len', 32) 117 | 118 | if cache: 119 | self.save(dataset_config, neuron_stimulus_dataset) 120 | 121 | return neuron_stimulus_dataset 122 | 123 | def prepare_dataset(self, exp_cfg): 124 | raise NotImplementedError( 125 | 'Currently neuron stimulus is only for activations') 126 | -------------------------------------------------------------------------------- /probing_datasets/ngrams.py: -------------------------------------------------------------------------------- 1 | from .common import * 2 | from transformer_lens import utils 3 | import numpy as np 4 | 5 | COMPOUND_WORDS = [ 6 | ('high', 'school'), 7 | ('living', 'room'), 8 | ('social', 'security'), 9 | ('credit', 'card'), 10 | ('blood', 'pressure'), 11 | ('prime', 'factors'), 12 | ('social', 'media'), 13 | ('gene', 'expression'), 14 | ('control', 'group'), 15 | ('magnetic', 'field'), 16 | ('cell', 'lines'), 17 | ('trial', 'court'), 18 | ('second', 'derivative'), 19 | ('north', 'america'), 20 | ('human', 'rights'), 21 | ('side', 'effects'), 22 | ('public', 'health'), 23 | ('federal', 'government'), 24 | ('third', 'party'), 25 | ('clinical', 'trials'), 26 | ('mental', 'health'), 27 | ] 28 | 29 | 30 | class BigramFeatureDataset(FeatureDataset): 31 | def __init__(self, name, features): 32 | self.name = name 33 | self.features = features 34 | 35 | def make(self, dataset_config, args, raw_dataset, tokenizer, cache=True): 36 | tokenized_ds = utils.tokenize_and_concatenate(raw_dataset, tokenizer) 37 | all_tokens = tokenized_ds['tokens'][:, 1:].flatten().numpy() 38 | decoded_vocab = { 39 | tokenizer.decode(tix): tix 40 | for tix in tokenizer.get_vocab().values() 41 | } 42 | 43 | dataset_size = args.get('dataset_size', 8_000) 44 | target_positive_fraction = args.get('target_positive_fraction', 0.2) 45 | ctx_len = args.get('seq_len', 24) 46 | 47 | target_positive = int(dataset_size * target_positive_fraction) 48 | target_negative = dataset_size - target_positive 49 | 50 | bigram_datasets = [] 51 | for first, second in self.features: 52 | compound_first_tokens = set( 53 | [v for k, v in decoded_vocab.items() if k.lower().strip() == first]) 54 | compound_second_tokens = set( 55 | [v for k, v in decoded_vocab.items() if k.lower().strip() == second]) 56 | 57 | first_indicator = np.isin(all_tokens, list(compound_first_tokens)) 58 | second_indicator = np.isin( 59 | 
all_tokens, list(compound_second_tokens)) 60 | 61 | first_occurences = np.where(first_indicator)[0] 62 | second_occurences = np.where(second_indicator)[0] 63 | 64 | bigram_indicator = np.isin( 65 | all_tokens[second_occurences - 1], list(compound_first_tokens)) 66 | bigram_occurences = second_occurences[ 67 | np.where(bigram_indicator)[0]] 68 | not_first_and_second_occurences = second_occurences[ 69 | np.where(~bigram_indicator)[0]] 70 | first_and_not_second_occurences = first_occurences[np.where( 71 | ~second_indicator[first_occurences + 1])[0]] + 1 72 | 73 | n_pos = min(len(bigram_occurences), target_positive) 74 | 75 | pos_ixs = np.random.choice(bigram_occurences, n_pos, replace=False) 76 | not_first_neg_ixs = np.random.choice( 77 | not_first_and_second_occurences, int(target_negative / 2), replace=False) 78 | not_second_neg_ixs = np.random.choice( 79 | first_and_not_second_occurences, int(target_negative / 2), replace=False) 80 | 81 | all_ixs = np.sort(np.concatenate( 82 | [pos_ixs, not_first_neg_ixs, not_second_neg_ixs])) 83 | 84 | token_tensor = np.vstack([ 85 | all_tokens[ix+1-ctx_len: ix+1] for ix in all_ixs 86 | ]) 87 | 88 | feature_name = [f'{first}-{second}' for _ in range(len(all_ixs))] 89 | 90 | pos_set = set(pos_ixs) 91 | not_first_set = set(not_first_neg_ixs) 92 | 93 | label = [ 94 | 'bigram' if ix in pos_set 95 | else ('missing_first' if ix in not_first_set else 'missing_second') 96 | for ix in all_ixs 97 | ] 98 | 99 | ds = datasets.Dataset.from_dict({ 100 | 'tokens': token_tensor, 101 | 'label': label, 102 | 'feature_name': feature_name 103 | }) 104 | ds.set_format(type="torch") 105 | 106 | bigram_datasets.append(ds) 107 | print(f'Finished processing {first}-{second}') 108 | 109 | full_ds = datasets.concatenate_datasets(bigram_datasets) 110 | 111 | if cache: 112 | self.save(dataset_config, full_ds) 113 | 114 | return full_ds 115 | 116 | def prepare_dataset(self, exp_cfg): 117 | full_ds = self.load(exp_cfg.dataset_cfg) 118 | ctx_len = exp_cfg.dataset_cfg.ctx_len 119 | offset = -2 if exp_cfg.probe_next_token_feature else -1 120 | feature_datasets = {} 121 | for first, second in self.features: 122 | feature_name = f'{first}-{second}' 123 | feature_ixs = np.array(full_ds['feature_name']) == feature_name 124 | feature_subset = full_ds.select(np.where(feature_ixs)[0]) 125 | 126 | label_arr = np.array(feature_subset['label']) 127 | label = (label_arr == 'bigram').astype(int) * 2 - 1 128 | indices = np.arange(1, len(feature_subset) + 1) * ctx_len + offset 129 | # assumes features are contiguous 130 | subset_offset = np.min(np.where(feature_ixs)[0]) 131 | indices += subset_offset * ctx_len 132 | 133 | feature_datasets[feature_name] = (indices, label) 134 | 135 | return full_ds, feature_datasets 136 | -------------------------------------------------------------------------------- /probing_datasets/pile_test.py: -------------------------------------------------------------------------------- 1 | import os 2 | import datasets 3 | import torch 4 | import numpy as np 5 | from .common import * 6 | from transformer_lens.utils import tokenize_and_concatenate 7 | 8 | 9 | class PileTestSplitFeatureDataset(FeatureDataset): 10 | """ 11 | A large tokenized text dataset corresponding to the Pile test set. 12 | 13 | Tokenized to have minimal padding but where all concatenated sequences 14 | come from the same sub distribution (with the label). 
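    Each tokenized sequence also carries a 'distribution' column holding the Pile
    subset name taken from the example metadata ('pile_set_name'); see make() below.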
15 | """ 16 | 17 | def __init__(self, name): 18 | self.name = name 19 | 20 | def prepare_dataset(self, exp_cfg): 21 | raise NotImplementedError( 22 | "PileTestSplitFeatureDataset is not meant to be used for probing, only activation statistics." 23 | ) 24 | 25 | def make(self, dataset_config, args, raw_dataset, tokenizer, cache=True): 26 | 27 | data_split = np.array([m['pile_set_name'] 28 | for m in raw_dataset['meta']]) 29 | splits = np.unique(data_split) 30 | 31 | sub_datasets = [] 32 | for split in splits: 33 | print(f"Tokenizing {split}...") 34 | sub_dataset_indices = np.where(data_split == split)[0] 35 | sub_dataset = raw_dataset.select(sub_dataset_indices) 36 | 37 | tokenized_sub_dataset = tokenize_and_concatenate( 38 | sub_dataset, 39 | tokenizer, 40 | max_length=dataset_config.ctx_len, 41 | add_bos_token=args.get('add_bos_token', True), 42 | ) 43 | tokenized_sub_dataset = tokenized_sub_dataset.add_column( 44 | 'distribution', 45 | [split for _ in range(len(tokenized_sub_dataset))] 46 | ) 47 | sub_datasets.append(tokenized_sub_dataset) 48 | 49 | tokenized_dataset = datasets.concatenate_datasets(sub_datasets) 50 | 51 | if cache: 52 | self.save(dataset_config, tokenized_dataset) 53 | 54 | return tokenized_dataset 55 | -------------------------------------------------------------------------------- /probing_datasets/position.py: -------------------------------------------------------------------------------- 1 | import os 2 | import datasets 3 | import einops 4 | import torch 5 | import numpy as np 6 | from .common import * 7 | from .pile_test import PileTestSplitFeatureDataset 8 | from transformer_lens.utils import tokenize_and_concatenate 9 | 10 | 11 | class PositionFeatureDataset(FeatureDataset): 12 | """ 13 | Pile test set with minimal padding and position columns. 
14 | """ 15 | 16 | def __init__(self): 17 | self.name = 'position' 18 | 19 | def prepare_dataset(self, exp_cfg): 20 | dataset = self.load(exp_cfg.dataset_cfg) 21 | index_mask = dataset['index_mask'] 22 | probe_indices = np.where(index_mask.flatten())[0] 23 | 24 | pos_cols = ['abs_pos', 'norm_abs_pos', 25 | 'rel_pos', 'norm_rel_pos', 'log_pos'] 26 | feature_datasets = { 27 | col: (probe_indices, dataset[col][index_mask].numpy()) 28 | for col in pos_cols 29 | } 30 | return dataset, feature_datasets 31 | 32 | def make(self, dataset_config, args, raw_dataset, tokenizer, cache=True): 33 | 34 | # reuse PileTestSplitFeatureDataset make() 35 | tokenized_dataset = PileTestSplitFeatureDataset('').make( 36 | dataset_config, args, raw_dataset, tokenizer, cache=False) 37 | tokenized_dataset = tokenized_dataset.select( 38 | range(args.get('n_seqs', 10_000))) 39 | 40 | n, d, = tokenized_dataset['tokens'].shape 41 | 42 | abs_pos = np.arange(d).astype(np.float32) 43 | norm_abs_pos = abs_pos / abs_pos.std() 44 | 45 | rel_pos = abs_pos - abs_pos.mean() 46 | norm_rel_pos = rel_pos / rel_pos.std() 47 | 48 | log_pos = np.log2(abs_pos + 1) 49 | 50 | # add position columns to dataset 51 | tokenized_dataset = tokenized_dataset.add_column( 52 | 'abs_pos', [abs_pos for _ in range(n)]) 53 | tokenized_dataset = tokenized_dataset.add_column( 54 | 'norm_abs_pos', [norm_abs_pos for _ in range(n)]) 55 | tokenized_dataset = tokenized_dataset.add_column( 56 | 'rel_pos', [rel_pos for _ in range(n)]) 57 | tokenized_dataset = tokenized_dataset.add_column( 58 | 'norm_rel_pos', [norm_rel_pos for _ in range(n)]) 59 | tokenized_dataset = tokenized_dataset.add_column( 60 | 'log_pos', [log_pos for _ in range(n)]) 61 | 62 | s = args.get('dataset_size', 50_000) 63 | probe_indices = np.random.choice(n*d, s, replace=False) 64 | index_mask = np.zeros(n*d, dtype=bool) 65 | index_mask[probe_indices] = True 66 | index_mask = index_mask.reshape((n, d)) 67 | 68 | tokenized_dataset = tokenized_dataset.add_column( 69 | 'index_mask', [index_mask[i] for i in range(n)]) 70 | 71 | if cache: 72 | self.save(dataset_config, tokenized_dataset) 73 | 74 | return tokenized_dataset 75 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | gurobipy >= 10 2 | transformer-lens 3 | torch 4 | torchtext 5 | numpy 6 | pandas 7 | tqdm 8 | matplotlib 9 | scikit-learn 10 | scipy 11 | plotly 12 | ipython 13 | jupyter 14 | pylint 15 | autopep8 16 | seaborn 17 | notebook 18 | conllu 19 | circuitsvis 20 | zstandard 21 | spacy[apple] 22 | spacy-alignments 23 | sparqlwrapper 24 | google-re2 25 | numba 26 | -------------------------------------------------------------------------------- /save_weight_statistics.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import torch 4 | import numpy as np 5 | from load import load_model 6 | 7 | 8 | def compute_and_save_weight_statistics(args): 9 | model = load_model(args.model_name, device='cpu') 10 | 11 | in_norm = model.W_in.norm(dim=1).numpy() 12 | in_bias = model.b_in.numpy() 13 | out_norm = model.W_out.norm(dim=-1).numpy() 14 | out_bias = model.b_out.numpy() 15 | cos = torch.nn.CosineSimilarity()(model.W_in, torch.swapaxes(model.W_out, 1, 2)) 16 | 17 | n_layers, n_neurons = in_norm.shape 18 | statistics = np.zeros((5, n_layers, n_neurons)) 19 | statistics[0] = in_norm 20 | statistics[1] = in_bias 21 | statistics[2] = out_norm 
22 | statistics[3, :, :len(out_bias[0])] = out_bias 23 | statistics[4] = cos 24 | 25 | save_dir = os.path.join( 26 | os.environ.get('RESULTS_DIR', 'results'), 27 | 'weight_statistics' 28 | ) 29 | os.makedirs(save_dir, exist_ok=True) 30 | save_file = os.path.join(save_dir, f'{args.model_name}.npy') 31 | np.save(save_file, statistics) 32 | 33 | 34 | def load_weight_statistics(model_name, save_dir=None): 35 | if save_dir is None: 36 | save_dir = os.path.join( 37 | os.environ.get('RESULTS_DIR', 'results'), 38 | 'weight_statistics' 39 | ) 40 | stats = np.load(os.path.join(save_dir, f'{model_name}.npy')) 41 | _, _, n_neurons = stats.shape 42 | return { 43 | 'in_norm': stats[0], 44 | 'in_bias': stats[1], 45 | 'out_norm': stats[2], 46 | 'out_bias': stats[3, :, :n_neurons//4], 47 | 'cos': stats[4] 48 | } 49 | 50 | 51 | if __name__ == "__main__": 52 | parser = argparse.ArgumentParser() 53 | parser.add_argument('--model_name', type=str, help='Model name') 54 | args = parser.parse_args() 55 | compute_and_save_weight_statistics(args) 56 | -------------------------------------------------------------------------------- /scripts/activations/layer_1_20_50_80_percentile.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH -o log/%j-percentile_save 3 | #SBATCH -c 20 4 | #SBATCH --gres=gpu:volta:1 5 | 6 | # setup env 7 | export PATH=$SPARSE_PROBING_ROOT:$PATH 8 | 9 | export HF_DATASETS_OFFLINE=1 10 | export TRANSFORMERS_OFFLINE=1 11 | 12 | export RESULTS_DIR=/home/gridsan/groups/maia_mechint/sparse_probing/results 13 | export FEATURE_DATASET_DIR=/home/gridsan/groups/maia_mechint/sparse_probing/feature_datasets 14 | export TRANSFORMERS_CACHE=/home/gridsan/groups/maia_mechint/models 15 | export HF_DATASETS_CACHE=/home/gridsan/groups/maia_mechint/sparse_probing/hf_home 16 | export HF_HOME=/home/gridsan/groups/maia_mechint/sparse_probing/hf_home 17 | 18 | sleep 0.1 # wait for paths to update 19 | 20 | source $SPARSE_PROBING_ROOT/sparprob/bin/activate 21 | source /etc/profile 22 | 23 | PYTHIA_MODELS=('pythia-70m' 'pythia-160m' 'pythia-410m' 'pythia-1b' 'pythia-1.4b' 'pythia-2.8b' 'pythia-6.9b') 24 | 25 | for model in "${PYTHIA_MODELS[@]}" 26 | do 27 | python -u get_activations.py \ 28 | --experiment_name layer_1_percentile_activations \ 29 | --experiment_type activation_subset \ 30 | --feature_dataset pile_test.pyth.512.-1 \ 31 | --model "$model" \ 32 | --batch_size 32 \ 33 | --neuron_subset_file biasxnorm_p20_50_80.csv 34 | done 35 | 36 | -------------------------------------------------------------------------------- /scripts/activations/layer_1_neuron_0_act.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH -o log/%j-layer_1_neuron_0_save.log 3 | #SBATCH -c 20 4 | #SBATCH --gres=gpu:volta:1 5 | 6 | # setup env 7 | export PATH=$SPARSE_PROBING_ROOT:$PATH 8 | 9 | export HF_DATASETS_OFFLINE=1 10 | export TRANSFORMERS_OFFLINE=1 11 | 12 | export RESULTS_DIR=/home/gridsan/groups/maia_mechint/sparse_probing/results 13 | export FEATURE_DATASET_DIR=/home/gridsan/groups/maia_mechint/sparse_probing/feature_datasets 14 | export TRANSFORMERS_CACHE=/home/gridsan/groups/maia_mechint/models 15 | export HF_DATASETS_CACHE=/home/gridsan/groups/maia_mechint/sparse_probing/hf_home 16 | export HF_HOME=/home/gridsan/groups/maia_mechint/sparse_probing/hf_home 17 | 18 | sleep 0.1 # wait for paths to update 19 | 20 | source $SPARSE_PROBING_ROOT/sparprob/bin/activate 21 | source /etc/profile 22 | 23 | 
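# Save pile-test activations for a single neuron across all Pythia sizes; the
# '--neuron_subset 1,0' argument below presumably selects (layer 1, neuron 0),
# matching the script and experiment names.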
PYTHIA_MODELS=('pythia-70m' 'pythia-160m' 'pythia-410m' 'pythia-1b' 'pythia-1.4b' 'pythia-2.8b' 'pythia-6.9b') 24 | 25 | for model in "${PYTHIA_MODELS[@]}" 26 | do 27 | python get_activations.py \ 28 | --experiment_name layer_1_neuron_0_pile_test_activations \ 29 | --experiment_type activation_subset \ 30 | --feature_dataset pile_test.pyth.512.-1 \ 31 | --model "$model" \ 32 | --batch_size 16 \ 33 | --neuron_subset 1,0 34 | done 35 | 36 | -------------------------------------------------------------------------------- /scripts/activations/run_activation_all.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH -o log/activation_all.log-%j 3 | #SBATCH -c 20 4 | #SBATCH --gres=gpu:volta:1 5 | 6 | # setup env 7 | export PATH=$SPARSE_PROBING_ROOT:$PATH 8 | 9 | export HF_DATASETS_OFFLINE=1 10 | export TRANSFORMERS_OFFLINE=1 11 | 12 | export RESULTS_DIR=/home/gridsan/groups/maia_mechint/sparse_probing/results 13 | export FEATURE_DATASET_DIR=/home/gridsan/groups/maia_mechint/sparse_probing/feature_datasets 14 | export TRANSFORMERS_CACHE=/home/gridsan/groups/maia_mechint/models 15 | export HF_DATASETS_CACHE=/home/gridsan/groups/maia_mechint/sparse_probing/hf_home 16 | export HF_HOME=/home/gridsan/groups/maia_mechint/sparse_probing/hf_home 17 | 18 | sleep 0.1 # wait for paths to update 19 | 20 | source $SPARSE_PROBING_ROOT/sparprob/bin/activate 21 | source /etc/profile 22 | 23 | python get_activations.py \ 24 | --experiment_name pyth70m_n1_111_stimuli \ 25 | --experiment_type all_activations \ 26 | --feature_dataset neuron_stimulus.pyth.32.-1 \ 27 | --batch_size 64 \ 28 | --model pythia-70m \ 29 | --layers 1 \ 30 | --positions 31 31 | -------------------------------------------------------------------------------- /scripts/activations/run_activation_metrics.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH -o log/activation_metrics.log-%j 3 | #SBATCH -c 20 4 | #SBATCH --gres=gpu:volta:1 5 | 6 | # setup env 7 | export PATH=$SPARSE_PROBING_ROOT:$PATH 8 | 9 | export HF_DATASETS_OFFLINE=1 10 | export TRANSFORMERS_OFFLINE=1 11 | 12 | export RESULTS_DIR=/home/gridsan/groups/maia_mechint/sparse_probing/results 13 | export FEATURE_DATASET_DIR=/home/gridsan/groups/maia_mechint/sparse_probing/feature_datasets 14 | export TRANSFORMERS_CACHE=/home/gridsan/groups/maia_mechint/models 15 | export HF_DATASETS_CACHE=/home/gridsan/groups/maia_mechint/sparse_probing/hf_home 16 | export HF_HOME=/home/gridsan/groups/maia_mechint/sparse_probing/hf_home 17 | 18 | sleep 0.1 # wait for paths to update 19 | 20 | source $SPARSE_PROBING_ROOT/sparprob/bin/activate 21 | source /etc/profile 22 | 23 | PYTHIA_MODELS=('pythia-70m' 'pythia-160m' 'pythia-410m' 'pythia-1b' 'pythia-1.4b') 24 | 25 | for model in "${PYTHIA_MODELS[@]}" 26 | do 27 | python get_activations.py \ 28 | --experiment_name full_range \ 29 | --experiment_type full_activation_histogram \ 30 | --feature_dataset pile_test.pyth.512.-1 \ 31 | --n_bin 1000 --hist_max 100 --hist_min -100 \ 32 | --model "$model" 33 | done -------------------------------------------------------------------------------- /scripts/activations/save_all_ewt.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH -o log/activation_all.log-%j 3 | #SBATCH -c 20 4 | #SBATCH --gres=gpu:volta:1 5 | 6 | # setup env 7 | export PATH=$SPARSE_PROBING_ROOT:$PATH 8 | 9 | export HF_DATASETS_OFFLINE=1 10 | export 
TRANSFORMERS_OFFLINE=1 11 | 12 | export RESULTS_DIR=/home/gridsan/groups/maia_mechint/sparse_probing/results 13 | export FEATURE_DATASET_DIR=/home/gridsan/groups/maia_mechint/sparse_probing/feature_datasets 14 | export TRANSFORMERS_CACHE=/home/gridsan/groups/maia_mechint/models 15 | export HF_DATASETS_CACHE=/home/gridsan/groups/maia_mechint/sparse_probing/hf_home 16 | export HF_HOME=/home/gridsan/groups/maia_mechint/sparse_probing/hf_home 17 | 18 | sleep 0.1 # wait for paths to update 19 | 20 | source $SPARSE_PROBING_ROOT/sparprob/bin/activate 21 | source /etc/profile 22 | 23 | python get_activations.py \ 24 | --experiment_name ewt_full \ 25 | --experiment_type all_activations \ 26 | --feature_dataset preprocessed_ewt_512.hf \ 27 | --batch_size 64 \ 28 | --model pythia-70m \ 29 | --flatten_and_ignore_padding 30 | 31 | 32 | python get_activations.py \ 33 | --experiment_name ewt_full \ 34 | --experiment_type all_activations \ 35 | --feature_dataset preprocessed_ewt_512.hf \ 36 | --batch_size 64 \ 37 | --model pythia-1b \ 38 | --flatten_and_ignore_padding 39 | -------------------------------------------------------------------------------- /scripts/activations/save_all_neurons_of_interest.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH -o log/%j-save_top_mono_neurons.log 3 | #SBATCH -c 20 4 | #SBATCH --gres=gpu:volta:1 5 | 6 | # setup env 7 | export PATH=$SPARSE_PROBING_ROOT:$PATH 8 | 9 | export HF_DATASETS_OFFLINE=1 10 | export TRANSFORMERS_OFFLINE=1 11 | 12 | export RESULTS_DIR=/home/gridsan/groups/maia_mechint/sparse_probing/results 13 | export FEATURE_DATASET_DIR=/home/gridsan/groups/maia_mechint/sparse_probing/feature_datasets 14 | export TRANSFORMERS_CACHE=/home/gridsan/groups/maia_mechint/models 15 | export HF_DATASETS_CACHE=/home/gridsan/groups/maia_mechint/sparse_probing/hf_home 16 | export HF_HOME=/home/gridsan/groups/maia_mechint/sparse_probing/hf_home 17 | 18 | sleep 0.1 # wait for paths to update 19 | 20 | source $SPARSE_PROBING_ROOT/sparprob/bin/activate 21 | source /etc/profile 22 | 23 | PYTHIA_MODELS=('pythia-70m' 'pythia-160m' 'pythia-410m' 'pythia-1b' 'pythia-1.4b' 'pythia-2.8b' 'pythia-6.9b') 24 | 25 | for model in "${PYTHIA_MODELS[@]}" 26 | do 27 | python get_activations.py \ 28 | --experiment_name top_mono_neurons \ 29 | --experiment_type activation_subset \ 30 | --feature_dataset programming_lang_id.pyth.512.-1 \ 31 | --model "$model" \ 32 | --neuron_subset_file top_mono_neurons.csv \ 33 | --auto_restrict_neuron_subset_file 34 | 35 | python get_activations.py \ 36 | --experiment_name top_mono_neurons \ 37 | --experiment_type activation_subset \ 38 | --feature_dataset distribution_id.pyth.512.-1 \ 39 | --model "$model" \ 40 | --neuron_subset_file top_mono_neurons.csv \ 41 | --auto_restrict_neuron_subset_file 42 | 43 | # python get_activations.py \ 44 | # --experiment_name top_mono_neurons \ 45 | # --experiment_type activation_subset \ 46 | # --feature_dataset natural_lang_id.pyth.512.-1 \ 47 | # --model "$model" \ 48 | # --neuron_subset_file top_mono_neurons.csv \ 49 | # --auto_restrict_neuron_subset_file \ 50 | # --skip_computing_token_summary_df 51 | 52 | python get_activations.py \ 53 | --experiment_name top_mono_neurons \ 54 | --experiment_type activation_subset \ 55 | --feature_dataset compound_words.pyth.24.-1 \ 56 | --model "$model" \ 57 | --neuron_subset_file top_mono_neurons.csv \ 58 | --auto_restrict_neuron_subset_file \ 59 | --skip_computing_token_summary_df 60 | 61 | python get_activations.py \ 
62 | --experiment_name top_mono_neurons \ 63 | --experiment_type activation_subset \ 64 | --feature_dataset text_features.pyth.256.10000 \ 65 | --model "$model" \ 66 | --neuron_subset_file top_mono_neurons.csv \ 67 | 68 | python get_activations.py \ 69 | --experiment_name top_mono_neurons \ 70 | --experiment_type activation_subset \ 71 | --feature_dataset ewt.pyth.512.-1 \ 72 | --model "$model" \ 73 | --neuron_subset_file top_mono_neurons.csv \ 74 | --auto_restrict_neuron_subset_file \ 75 | --skip_computing_token_summary_df 76 | 77 | python get_activations.py \ 78 | --experiment_name top_mono_neurons \ 79 | --experiment_type activation_subset \ 80 | --feature_dataset latex.pyth.1024.-1 \ 81 | --model "$model" \ 82 | --batch_size 8 \ 83 | --neuron_subset_file top_mono_neurons.csv \ 84 | --auto_restrict_neuron_subset_file 85 | done -------------------------------------------------------------------------------- /scripts/activations/save_compound_words_subset.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH -o log/%j-compound_subset.log 3 | #SBATCH -c 20 4 | #SBATCH --gres=gpu:volta:1 5 | 6 | # setup env 7 | export PATH=$SPARSE_PROBING_ROOT:$PATH 8 | 9 | export HF_DATASETS_OFFLINE=1 10 | export TRANSFORMERS_OFFLINE=1 11 | 12 | export RESULTS_DIR=/home/gridsan/groups/maia_mechint/sparse_probing/results 13 | export FEATURE_DATASET_DIR=/home/gridsan/groups/maia_mechint/sparse_probing/feature_datasets 14 | export TRANSFORMERS_CACHE=/home/gridsan/groups/maia_mechint/models 15 | export HF_DATASETS_CACHE=/home/gridsan/groups/maia_mechint/sparse_probing/hf_home 16 | export HF_HOME=/home/gridsan/groups/maia_mechint/sparse_probing/hf_home 17 | 18 | sleep 0.1 # wait for paths to update 19 | 20 | source $SPARSE_PROBING_ROOT/sparprob/bin/activate 21 | source /etc/profile 22 | 23 | python get_activations.py \ 24 | --experiment_name compound_superposition \ 25 | --experiment_type activation_subset \ 26 | --feature_dataset compound_words.pyth.24.-1 \ 27 | --model pythia-1b \ 28 | --batch_size 256 \ 29 | --save_by_neuron \ 30 | --skip_computing_token_summary_df \ 31 | --neuron_subset_file compound_superposition.csv 32 | 33 | 34 | python get_activations.py \ 35 | --experiment_name compound_superposition \ 36 | --experiment_type activation_subset \ 37 | --feature_dataset pile_test.pyth.512.-1 \ 38 | --model pythia-1b \ 39 | --batch_size 64 \ 40 | --save_by_neuron \ 41 | --skip_computing_token_summary_df \ 42 | --neuron_subset_file compound_superposition.csv 43 | 44 | 45 | # python get_activations.py \ 46 | # --experiment_name top_compound_words \ 47 | # --experiment_type activation_subset \ 48 | # --feature_dataset pile_test.pyth.512.-1 \ 49 | # --model pythia-160m \ 50 | # --batch_size 128 \ 51 | # --neuron_subset_file top_compound_words.csv 52 | 53 | 54 | # python get_activations.py \ 55 | # --experiment_name top_compound_words \ 56 | # --experiment_type activation_subset \ 57 | # --feature_dataset pile_test.pyth.512.-1 \ 58 | # --model pythia-2.8b \ 59 | # --batch_size 32 \ 60 | # --neuron_subset_file top_compound_words.csv 61 | -------------------------------------------------------------------------------- /scripts/activations/save_context_subset.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH -o log/%j-all_context_subset.log 3 | #SBATCH -c 20 4 | #SBATCH --gres=gpu:volta:1 5 | 6 | # setup env 7 | export PATH=$SPARSE_PROBING_ROOT:$PATH 8 | 9 | export HF_DATASETS_OFFLINE=1 10 | 
export TRANSFORMERS_OFFLINE=1 11 | 12 | export RESULTS_DIR=/home/gridsan/groups/maia_mechint/sparse_probing/results 13 | export FEATURE_DATASET_DIR=/home/gridsan/groups/maia_mechint/sparse_probing/feature_datasets 14 | export TRANSFORMERS_CACHE=/home/gridsan/groups/maia_mechint/models 15 | export HF_DATASETS_CACHE=/home/gridsan/groups/maia_mechint/sparse_probing/hf_home 16 | export HF_HOME=/home/gridsan/groups/maia_mechint/sparse_probing/hf_home 17 | 18 | sleep 0.1 # wait for paths to update 19 | 20 | source $SPARSE_PROBING_ROOT/sparprob/bin/activate 21 | source /etc/profile 22 | 23 | python get_activations.py \ 24 | --experiment_name context_monosemantic \ 25 | --experiment_type activation_subset \ 26 | --feature_dataset natural_lang_id.pyth.512.-1 \ 27 | --model pythia-70m \ 28 | --batch_size 256 \ 29 | --neuron_subset_file monosemantic_language_neurons.csv 30 | 31 | 32 | python get_activations.py \ 33 | --experiment_name context_monosemantic \ 34 | --experiment_type activation_subset \ 35 | --feature_dataset programming_lang_id.pyth.512.-1 \ 36 | --model pythia-1b \ 37 | --batch_size 128 \ 38 | --neuron_subset_file monosemantic_code_neurons.csv 39 | 40 | 41 | python get_activations.py \ 42 | --experiment_name context_monosemantic \ 43 | --experiment_type activation_subset \ 44 | --feature_dataset distribution_id.pyth.512.-1 \ 45 | --model pythia-6.9b \ 46 | --batch_size 32 \ 47 | --neuron_subset_file monosemantic_distribution_neurons.csv 48 | 49 | 50 | python get_activations.py \ 51 | --experiment_name context_monosemantic \ 52 | --experiment_type activation_subset \ 53 | --feature_dataset pile_test.pyth.512.-1 \ 54 | --model pythia-70m \ 55 | --batch_size 256 \ 56 | --neuron_subset_file monosemantic_language_neurons.csv 57 | 58 | 59 | python get_activations.py \ 60 | --experiment_name context_monosemantic \ 61 | --experiment_type activation_subset \ 62 | --feature_dataset pile_test.pyth.512.-1 \ 63 | --model pythia-1b \ 64 | --batch_size 128 \ 65 | --neuron_subset_file monosemantic_code_neurons.csv 66 | -------------------------------------------------------------------------------- /scripts/activations/save_fact_neurons.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH -o log/%j-top_fact_neurons.log 3 | #SBATCH -c 20 4 | #SBATCH --gres=gpu:volta:1 5 | 6 | # setup env 7 | export PATH=$SPARSE_PROBING_ROOT:$PATH 8 | 9 | export HF_DATASETS_OFFLINE=1 10 | export TRANSFORMERS_OFFLINE=1 11 | 12 | export RESULTS_DIR=/home/gridsan/groups/maia_mechint/sparse_probing/results 13 | export FEATURE_DATASET_DIR=/home/gridsan/groups/maia_mechint/sparse_probing/feature_datasets 14 | export TRANSFORMERS_CACHE=/home/gridsan/groups/maia_mechint/models 15 | export HF_DATASETS_CACHE=/home/gridsan/groups/maia_mechint/sparse_probing/hf_home 16 | export HF_HOME=/home/gridsan/groups/maia_mechint/sparse_probing/hf_home 17 | 18 | sleep 0.1 # wait for paths to update 19 | 20 | source $SPARSE_PROBING_ROOT/sparprob/bin/activate 21 | source /etc/profile 22 | 23 | PYTHIA_MODELS=('pythia-6.9b') 24 | 25 | for model in "${PYTHIA_MODELS[@]}" 26 | do 27 | # python -u get_activations.py \ 28 | # --experiment_name top_fact_neurons \ 29 | # --experiment_type activation_subset \ 30 | # --feature_dataset wikidata_sorted_occupation_athlete.pyth.128.5000 \ 31 | # --model "$model" \ 32 | # --neuron_subset_file top_fact_neurons.csv \ 33 | # --auto_restrict_neuron_subset_file 34 | 35 | # python -u get_activations.py \ 36 | # --experiment_name top_fact_neurons \ 37 | # 
--experiment_type activation_subset \ 38 | # --feature_dataset wikidata_sorted_is_alive.pyth.128.6000 \ 39 | # --model "$model" \ 40 | # --neuron_subset_file top_fact_neurons.csv \ 41 | # --auto_restrict_neuron_subset_file 42 | 43 | # python -u get_activations.py \ 44 | # --experiment_name top_fact_neurons \ 45 | # --experiment_type activation_subset \ 46 | # --feature_dataset wikidata_sorted_occupation.pyth.128.6000 \ 47 | # --model "$model" \ 48 | # --neuron_subset_file top_fact_neurons.csv \ 49 | # --auto_restrict_neuron_subset_file 50 | 51 | # python -u get_activations.py \ 52 | # --experiment_name top_fact_neurons \ 53 | # --experiment_type activation_subset \ 54 | # --feature_dataset wikidata_sorted_political_party.pyth.128.3000 \ 55 | # --model "$model" \ 56 | # --neuron_subset_file top_fact_neurons.csv \ 57 | # --auto_restrict_neuron_subset_file 58 | 59 | # python -u get_activations.py \ 60 | # --experiment_name top_fact_neurons \ 61 | # --experiment_type activation_subset \ 62 | # --feature_dataset wikidata_sorted_sex_or_gender.pyth.128.6000 \ 63 | # --model "$model" \ 64 | # --neuron_subset_file top_fact_neurons.csv \ 65 | # --auto_restrict_neuron_subset_file 66 | 67 | python -u get_activations.py \ 68 | --experiment_name top_fact_neurons \ 69 | --experiment_type activation_subset \ 70 | --feature_dataset distribution_id.pyth.512.-1 \ 71 | --model "$model" \ 72 | --neuron_subset_file top_fact_neurons.csv 73 | done 74 | 75 | -------------------------------------------------------------------------------- /scripts/experiments/enumerate_monosemantic_all.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --mem-per-cpu=4G 3 | #SBATCH -N 1 4 | #SBATCH -c 6 5 | #SBATCH -o log/%j-enumerate_monosemantic.log 6 | #SBATCH -a 1-32 7 | 8 | 9 | # set environment variables 10 | export PATH=$SPARSE_PROBING_ROOT:$PATH 11 | 12 | export HF_DATASETS_OFFLINE=1 13 | export TRANSFORMERS_OFFLINE=1 14 | 15 | export RESULTS_DIR=/home/gridsan/groups/maia_mechint/sparse_probing/results 16 | export FEATURE_DATASET_DIR=/home/gridsan/groups/maia_mechint/sparse_probing/feature_datasets 17 | export TRANSFORMERS_CACHE=/home/gridsan/groups/maia_mechint/models 18 | export HF_DATASETS_CACHE=/home/gridsan/groups/maia_mechint/sparse_probing/hf_home 19 | export HF_HOME=/home/gridsan/groups/maia_mechint/sparse_probing/hf_home 20 | 21 | sleep 0.1 # wait for paths to update 22 | 23 | # activate environment and load modules 24 | source $SPARSE_PROBING_ROOT/sparprob/bin/activate 25 | source /etc/profile 26 | 27 | PYTHIA_MODELS=('pythia-70m' 'pythia-160m' 'pythia-410m' 'pythia-1b' 'pythia-1.4b' 'pythia-2.8b' 'pythia-6.9b') 28 | 29 | for model in "${PYTHIA_MODELS[@]}" 30 | do 31 | # python -u probing_experiment.py \ 32 | # --experiment_name monosemantic_sweep_all \ 33 | # --experiment_type enumerate_monosemantic \ 34 | # --model "$model" \ 35 | # --feature_dataset programming_lang_id.pyth.512.-1 \ 36 | # --activation_aggregation mean \ 37 | # --normalize_activations \ 38 | # --seed 42 \ 39 | # --save_features_together 40 | 41 | 42 | # python -u probing_experiment.py \ 43 | # --experiment_name monosemantic_sweep_all \ 44 | # --experiment_type enumerate_monosemantic \ 45 | # --model "$model" \ 46 | # --feature_dataset distribution_id.pyth.512.-1 \ 47 | # --activation_aggregation mean \ 48 | # --normalize_activations \ 49 | # --seed 42 \ 50 | # --save_features_together 51 | 52 | 53 | # python -u probing_experiment.py \ 54 | # --experiment_name monosemantic_sweep_all \ 
55 | # --experiment_type enumerate_monosemantic \ 56 | # --model "$model" \ 57 | # --feature_dataset natural_lang_id.pyth.512.-1 \ 58 | # --activation_aggregation mean \ 59 | # --normalize_activations \ 60 | # --seed 42 \ 61 | # --save_features_together 62 | 63 | 64 | # python -u probing_experiment.py \ 65 | # --experiment_name monosemantic_sweep_all \ 66 | # --experiment_type enumerate_monosemantic \ 67 | # --model "$model" \ 68 | # --feature_dataset compound_words.pyth.24.-1 \ 69 | # --normalize_activations \ 70 | # --seed 42 \ 71 | # --save_features_together 72 | 73 | 74 | # python -u probing_experiment.py \ 75 | # --experiment_name monosemantic_sweep_all \ 76 | # --experiment_type enumerate_monosemantic \ 77 | # --model "$model" \ 78 | # --feature_dataset text_features.pyth.256.10000 \ 79 | # --normalize_activations \ 80 | # --seed 42 \ 81 | # --save_features_together 82 | 83 | 84 | # python -u probing_experiment.py \ 85 | # --experiment_name monosemantic_sweep_all \ 86 | # --experiment_type enumerate_monosemantic \ 87 | # --model "$model" \ 88 | # --feature_dataset ewt.pyth.512.-1 \ 89 | # --normalize_activations \ 90 | # --seed 42 \ 91 | # --save_features_together 92 | 93 | 94 | # python -u probing_experiment.py \ 95 | # --experiment_name monosemantic_sweep_all \ 96 | # --experiment_type enumerate_monosemantic \ 97 | # --model "$model" \ 98 | # --feature_dataset latex.pyth.1024.-1 \ 99 | # --normalize_activations \ 100 | # --seed 42 \ 101 | # --save_features_together 102 | 103 | python -u probing_experiment.py \ 104 | --experiment_name monosemantic_sweep_all \ 105 | --experiment_type enumerate_monosemantic \ 106 | --feature_dataset wikidata_sorted_is_alive.pyth.128.6000 \ 107 | --model $model \ 108 | --activation_aggregation max \ 109 | --normalize_activations \ 110 | --seed 42 \ 111 | --save_features_together 112 | 113 | python -u probing_experiment.py \ 114 | --experiment_name monosemantic_sweep_all \ 115 | --experiment_type enumerate_monosemantic \ 116 | --feature_dataset wikidata_sorted_occupation.pyth.128.6000 \ 117 | --model $model \ 118 | --activation_aggregation max \ 119 | --normalize_activations \ 120 | --seed 42 \ 121 | --save_features_together 122 | 123 | python -u probing_experiment.py \ 124 | --experiment_name monosemantic_sweep_all \ 125 | --experiment_type enumerate_monosemantic \ 126 | --feature_dataset wikidata_sorted_occupation_athlete.pyth.128.5000 \ 127 | --model $model \ 128 | --activation_aggregation max \ 129 | --normalize_activations \ 130 | --seed 42 \ 131 | --save_features_together 132 | 133 | python -u probing_experiment.py \ 134 | --experiment_name monosemantic_sweep_all \ 135 | --experiment_type enumerate_monosemantic \ 136 | --feature_dataset wikidata_sorted_political_party.pyth.128.3000 \ 137 | --model $model \ 138 | --activation_aggregation max \ 139 | --normalize_activations \ 140 | --seed 42 \ 141 | --save_features_together 142 | 143 | python -u probing_experiment.py \ 144 | --experiment_name monosemantic_sweep_all \ 145 | --experiment_type enumerate_monosemantic \ 146 | --feature_dataset wikidata_sorted_sex_or_gender.pyth.128.6000 \ 147 | --model $model \ 148 | --activation_aggregation max \ 149 | --normalize_activations \ 150 | --seed 42 \ 151 | --save_features_together 152 | done -------------------------------------------------------------------------------- /scripts/experiments/feature_selection.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --mem-per-cpu=4G 3 | #SBATCH -N 1 4 | 
#SBATCH -c 4 5 | #SBATCH -o log/%j-%a-feature_selection.log 6 | #SBATCH -a 1-32 7 | 8 | 9 | # set environment variables 10 | export PATH=$SPARSE_PROBING_ROOT:$PATH 11 | 12 | export HF_DATASETS_OFFLINE=1 13 | export TRANSFORMERS_OFFLINE=1 14 | 15 | export RESULTS_DIR=/home/gridsan/groups/maia_mechint/sparse_probing/results 16 | export FEATURE_DATASET_DIR=/home/gridsan/groups/maia_mechint/sparse_probing/feature_datasets 17 | export TRANSFORMERS_CACHE=/home/gridsan/groups/maia_mechint/models 18 | export HF_DATASETS_CACHE=/home/gridsan/groups/maia_mechint/sparse_probing/hf_home 19 | export HF_HOME=/home/gridsan/groups/maia_mechint/sparse_probing/hf_home 20 | 21 | sleep 0.1 # wait for paths to update 22 | 23 | # activate environment and load modules 24 | source $SPARSE_PROBING_ROOT/sparprob/bin/activate 25 | source /etc/profile 26 | 27 | PYTHIA_MODELS=('pythia-70m' 'pythia-160m' 'pythia-410m' 'pythia-1b' 'pythia-1.4b' 'pythia-2.8b' 'pythia-6.9b') 28 | 29 | 30 | for model in "${PYTHIA_MODELS[@]}" 31 | do 32 | python -u probing_experiment.py \ 33 | --experiment_name test_full_feature_selection_comparison \ 34 | --experiment_type compare_feature_selection \ 35 | --model "$model" \ 36 | --feature_dataset programming_lang_id.pyth.512.-1 \ 37 | --activation_aggregation mean \ 38 | --normalize_activations \ 39 | --seed 42 \ 40 | --max_k 64 \ 41 | --save_features_together 42 | 43 | 44 | python -u probing_experiment.py \ 45 | --experiment_name test_full_feature_selection_comparison \ 46 | --experiment_type compare_feature_selection \ 47 | --model "$model" \ 48 | --feature_dataset distribution_id.pyth.512.-1 \ 49 | --activation_aggregation mean \ 50 | --normalize_activations \ 51 | --seed 42 \ 52 | --max_k 64 \ 53 | --save_features_together 54 | 55 | 56 | python -u probing_experiment.py \ 57 | --experiment_name test_full_feature_selection_comparison \ 58 | --experiment_type compare_feature_selection \ 59 | --model "$model" \ 60 | --feature_dataset natural_lang_id.pyth.512.-1 \ 61 | --activation_aggregation mean \ 62 | --normalize_activations \ 63 | --seed 42 \ 64 | --max_k 64 \ 65 | --save_features_together 66 | 67 | 68 | python -u probing_experiment.py \ 69 | --experiment_name test_full_feature_selection_comparison \ 70 | --experiment_type compare_feature_selection \ 71 | --model "$model" \ 72 | --feature_dataset compound_words.pyth.24.-1 \ 73 | --normalize_activations \ 74 | --seed 42 \ 75 | --max_k 64 \ 76 | --save_features_together 77 | 78 | 79 | python -u probing_experiment.py \ 80 | --experiment_name test_full_feature_selection_comparison \ 81 | --experiment_type compare_feature_selection \ 82 | --model "$model" \ 83 | --feature_dataset text_features.pyth.256.10000 \ 84 | --normalize_activations \ 85 | --seed 42 \ 86 | --max_k 64 \ 87 | --save_features_together 88 | 89 | 90 | python -u probing_experiment.py \ 91 | --experiment_name test_full_feature_selection_comparison \ 92 | --experiment_type compare_feature_selection \ 93 | --model "$model" \ 94 | --feature_dataset ewt.pyth.512.-1 \ 95 | --normalize_activations \ 96 | --seed 42 \ 97 | --max_k 64 \ 98 | --save_features_together 99 | 100 | 101 | python -u probing_experiment.py \ 102 | --experiment_name test_full_feature_selection_comparison \ 103 | --experiment_type compare_feature_selection \ 104 | --model "$model" \ 105 | --feature_dataset latex.pyth.1024.-1 \ 106 | --normalize_activations \ 107 | --seed 42 \ 108 | --max_k 64 \ 109 | --save_features_together 110 | done 
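# Hedged usage note: assuming the same environment variables are set, a single
# model/dataset combination from the loop above can also be run directly,
# outside of slurm, e.g.:
#
# python -u probing_experiment.py \
#     --experiment_name test_full_feature_selection_comparison \
#     --experiment_type compare_feature_selection \
#     --model pythia-70m \
#     --feature_dataset distribution_id.pyth.512.-1 \
#     --activation_aggregation mean \
#     --normalize_activations \
#     --seed 42 \
#     --max_k 64 \
#     --save_features_together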
-------------------------------------------------------------------------------- /scripts/experiments/osp_full.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --mem-per-cpu=4G 3 | #SBATCH -N 1 4 | #SBATCH -c 16 5 | #SBATCH -o log/%j-%a-osp_all.log 6 | #SBATCH -a 1-16 7 | 8 | 9 | # set environment variables 10 | export PATH=$SPARSE_PROBING_ROOT:$PATH 11 | 12 | export HF_DATASETS_OFFLINE=1 13 | export TRANSFORMERS_OFFLINE=1 14 | 15 | export RESULTS_DIR=/home/gridsan/groups/maia_mechint/sparse_probing/results 16 | export FEATURE_DATASET_DIR=/home/gridsan/groups/maia_mechint/sparse_probing/feature_datasets 17 | export TRANSFORMERS_CACHE=/home/gridsan/groups/maia_mechint/models 18 | export HF_DATASETS_CACHE=/home/gridsan/groups/maia_mechint/sparse_probing/hf_home 19 | export HF_HOME=/home/gridsan/groups/maia_mechint/sparse_probing/hf_home 20 | 21 | sleep 0.1 # wait for paths to update 22 | 23 | # activate environment and load modules 24 | source $SPARSE_PROBING_ROOT/sparprob/bin/activate 25 | source /etc/profile 26 | module load gurobi/gurobi-1000 27 | 28 | PYTHIA_MODELS=('pythia-70m' 'pythia-160m' 'pythia-410m' 'pythia-1b' 'pythia-1.4b' 'pythia-2.8b' 'pythia-6.9b') 29 | 30 | 31 | for model in "${PYTHIA_MODELS[@]}" 32 | do 33 | python -u probing_experiment.py \ 34 | --experiment_name test_full_feature_selection_comparison \ 35 | --experiment_type optimal_sparse_probing \ 36 | --model "$model" \ 37 | --feature_dataset programming_lang_id.pyth.512.-1 \ 38 | --activation_aggregation mean \ 39 | --normalize_activations \ 40 | --seed 42 \ 41 | --save_features_together 42 | 43 | 44 | python -u probing_experiment.py \ 45 | --experiment_name test_full_feature_selection_comparison \ 46 | --experiment_type optimal_sparse_probing \ 47 | --model "$model" \ 48 | --feature_dataset distribution_id.pyth.512.-1 \ 49 | --activation_aggregation mean \ 50 | --normalize_activations \ 51 | --seed 42 \ 52 | --save_features_together 53 | 54 | 55 | python -u probing_experiment.py \ 56 | --experiment_name test_full_feature_selection_comparison \ 57 | --experiment_type optimal_sparse_probing \ 58 | --model "$model" \ 59 | --feature_dataset natural_lang_id.pyth.512.-1 \ 60 | --activation_aggregation mean \ 61 | --normalize_activations \ 62 | --seed 42 \ 63 | --save_features_together 64 | 65 | 66 | python -u probing_experiment.py \ 67 | --experiment_name test_full_feature_selection_comparison \ 68 | --experiment_type optimal_sparse_probing \ 69 | --model "$model" \ 70 | --feature_dataset compound_words.pyth.24.-1 \ 71 | --normalize_activations \ 72 | --seed 42 \ 73 | --save_features_together 74 | 75 | 76 | python -u probing_experiment.py \ 77 | --experiment_name test_full_feature_selection_comparison \ 78 | --experiment_type optimal_sparse_probing \ 79 | --model "$model" \ 80 | --feature_dataset text_features.pyth.256.10000 \ 81 | --normalize_activations \ 82 | --seed 42 \ 83 | --save_features_together 84 | 85 | 86 | python -u probing_experiment.py \ 87 | --experiment_name test_full_feature_selection_comparison \ 88 | --experiment_type optimal_sparse_probing \ 89 | --model "$model" \ 90 | --feature_dataset ewt.pyth.512.-1 \ 91 | --normalize_activations \ 92 | --feature_subset not-dep \ 93 | --seed 42 \ 94 | --save_features_together 95 | 96 | 97 | python -u probing_experiment.py \ 98 | --experiment_name test_full_feature_selection_comparison \ 99 | --experiment_type optimal_sparse_probing \ 100 | --model "$model" \ 101 | --feature_dataset latex.pyth.1024.-1 \ 
102 | --normalize_activations \ 103 | --seed 42 \ 104 | --save_features_together 105 | done -------------------------------------------------------------------------------- /scripts/experiments/sparsity_sweep_all.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --mem-per-cpu=4G 3 | #SBATCH -N 1 4 | #SBATCH -c 4 5 | #SBATCH -o log/%j-%a-full_sparsity_sweep_final.log 6 | #SBATCH -a 1-32 7 | 8 | 9 | # set environment variables 10 | export PATH=$SPARSE_PROBING_ROOT:$PATH 11 | 12 | export HF_DATASETS_OFFLINE=1 13 | export TRANSFORMERS_OFFLINE=1 14 | 15 | export RESULTS_DIR=/home/gridsan/groups/maia_mechint/sparse_probing/results 16 | export FEATURE_DATASET_DIR=/home/gridsan/groups/maia_mechint/sparse_probing/feature_datasets 17 | export TRANSFORMERS_CACHE=/home/gridsan/groups/maia_mechint/models 18 | export HF_DATASETS_CACHE=/home/gridsan/groups/maia_mechint/sparse_probing/hf_home 19 | export HF_HOME=/home/gridsan/groups/maia_mechint/sparse_probing/hf_home 20 | 21 | sleep 0.1 # wait for paths to update 22 | 23 | # activate environment and load modules 24 | source $SPARSE_PROBING_ROOT/sparprob/bin/activate 25 | source /etc/profile 26 | module load gurobi/gurobi-1000 27 | 28 | PYTHIA_MODELS=('pythia-70m' 'pythia-160m' 'pythia-410m' 'pythia-1b' 'pythia-1.4b' 'pythia-2.8b' 'pythia-6.9b') 29 | 30 | 31 | for model in "${PYTHIA_MODELS[@]}" 32 | do 33 | python -u probing_experiment.py \ 34 | --experiment_name full_sparsity_sweep_final \ 35 | --experiment_type rotation_baseline_dxd telescopic_sparsity_sweep \ 36 | --model "$model" \ 37 | --feature_dataset programming_lang_id.pyth.512.-1 \ 38 | --activation_aggregation mean \ 39 | --normalize_activations \ 40 | --seed 42 \ 41 | --save_features_together 42 | 43 | 44 | python -u probing_experiment.py \ 45 | --experiment_name full_sparsity_sweep_final \ 46 | --experiment_type rotation_baseline_dxd telescopic_sparsity_sweep \ 47 | --model "$model" \ 48 | --feature_dataset distribution_id.pyth.512.-1 \ 49 | --activation_aggregation mean \ 50 | --normalize_activations \ 51 | --seed 42 \ 52 | --save_features_together 53 | 54 | 55 | python -u probing_experiment.py \ 56 | --experiment_name full_sparsity_sweep_final \ 57 | --experiment_type rotation_baseline_dxd telescopic_sparsity_sweep \ 58 | --model "$model" \ 59 | --feature_dataset natural_lang_id.pyth.512.-1 \ 60 | --activation_aggregation mean \ 61 | --normalize_activations \ 62 | --seed 42 \ 63 | --save_features_together 64 | 65 | 66 | python -u probing_experiment.py \ 67 | --experiment_name full_sparsity_sweep_final \ 68 | --experiment_type rotation_baseline_dxd telescopic_sparsity_sweep \ 69 | --model "$model" \ 70 | --feature_dataset compound_words.pyth.24.-1 \ 71 | --normalize_activations \ 72 | --seed 42 \ 73 | --save_features_together 74 | 75 | 76 | python -u probing_experiment.py \ 77 | --experiment_name full_sparsity_sweep_final \ 78 | --experiment_type rotation_baseline_dxd telescopic_sparsity_sweep \ 79 | --model "$model" \ 80 | --feature_dataset text_features.pyth.256.10000 \ 81 | --normalize_activations \ 82 | --seed 42 \ 83 | --save_features_together 84 | 85 | 86 | python -u probing_experiment.py \ 87 | --experiment_name full_sparsity_sweep_final \ 88 | --experiment_type rotation_baseline_dxd telescopic_sparsity_sweep \ 89 | --model "$model" \ 90 | --feature_dataset ewt.pyth.512.-1 \ 91 | --normalize_activations \ 92 | --seed 42 \ 93 | --save_features_together 94 | 95 | 96 | python -u probing_experiment.py \ 97 | --experiment_name 
full_sparsity_sweep_final \ 98 | --experiment_type rotation_baseline_dxd telescopic_sparsity_sweep \ 99 | --model "$model" \ 100 | --feature_dataset latex.pyth.1024.-1 \ 101 | --normalize_activations \ 102 | --seed 42 \ 103 | --save_features_together 104 | 105 | 106 | python -u probing_experiment.py \ 107 | --experiment_name full_sparsity_sweep_final \ 108 | --experiment_type rotation_baseline_dxd telescopic_sparsity_sweep \ 109 | --feature_dataset wikidata_sorted_is_alive.pyth.128.6000 \ 110 | --model $model \ 111 | --activation_aggregation max \ 112 | --normalize_activations \ 113 | --seed 42 \ 114 | --save_features_together 115 | 116 | python -u probing_experiment.py \ 117 | --experiment_name full_sparsity_sweep_final \ 118 | --experiment_type rotation_baseline_dxd telescopic_sparsity_sweep \ 119 | --feature_dataset wikidata_sorted_occupation.pyth.128.6000 \ 120 | --model $model \ 121 | --activation_aggregation max \ 122 | --normalize_activations \ 123 | --seed 42 \ 124 | --save_features_together 125 | 126 | python -u probing_experiment.py \ 127 | --experiment_name full_sparsity_sweep_final \ 128 | --experiment_type rotation_baseline_dxd telescopic_sparsity_sweep \ 129 | --feature_dataset wikidata_sorted_occupation_athlete.pyth.128.5000 \ 130 | --model $model \ 131 | --activation_aggregation max \ 132 | --normalize_activations \ 133 | --seed 42 \ 134 | --save_features_together 135 | 136 | python -u probing_experiment.py \ 137 | --experiment_name full_sparsity_sweep_final \ 138 | --experiment_type rotation_baseline_dxd telescopic_sparsity_sweep \ 139 | --feature_dataset wikidata_sorted_political_party.pyth.128.3000 \ 140 | --model $model \ 141 | --activation_aggregation max \ 142 | --normalize_activations \ 143 | --seed 42 \ 144 | --save_features_together 145 | 146 | python -u probing_experiment.py \ 147 | --experiment_name full_sparsity_sweep_final \ 148 | --experiment_type rotation_baseline_dxd telescopic_sparsity_sweep \ 149 | --feature_dataset wikidata_sorted_sex_or_gender.pyth.128.6000 \ 150 | --model $model \ 151 | --activation_aggregation max \ 152 | --normalize_activations \ 153 | --seed 42 \ 154 | --save_features_together 155 | done -------------------------------------------------------------------------------- /scripts/experiments/superposition_compound_words.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --mem-per-cpu=4G 3 | #SBATCH -N 1 4 | #SBATCH -c 4 5 | #SBATCH -o log/%j-compound_superposition.log 6 | #SBATCH -a 1-32 7 | 8 | 9 | # set environment variables 10 | export PATH=$SPARSE_PROBING_ROOT:$PATH 11 | 12 | export HF_DATASETS_OFFLINE=1 13 | export TRANSFORMERS_OFFLINE=1 14 | 15 | export RESULTS_DIR=/home/gridsan/groups/maia_mechint/sparse_probing/results 16 | export FEATURE_DATASET_DIR=/home/gridsan/groups/maia_mechint/sparse_probing/feature_datasets 17 | export TRANSFORMERS_CACHE=/home/gridsan/groups/maia_mechint/models 18 | export HF_DATASETS_CACHE=/home/gridsan/groups/maia_mechint/sparse_probing/hf_home 19 | export HF_HOME=/home/gridsan/groups/maia_mechint/sparse_probing/hf_home 20 | 21 | sleep 0.1 # wait for paths to update 22 | 23 | # activate environment and load modules 24 | source $SPARSE_PROBING_ROOT/sparprob/bin/activate 25 | source /etc/profile 26 | 27 | PYTHIA_MODELS=('pythia-1b' 'pythia-70m' 'pythia-160m' 'pythia-410m' 'pythia-1.4b' 'pythia-2.8b' 'pythia-6.9b') 28 | 29 | for model in "${PYTHIA_MODELS[@]}" 30 | do 31 | python -u probing_experiment.py \ 32 | --experiment_name 
compound_superposition_final \ 33 | --experiment_type iterative_pruning_rotation_baseline_dxd rotation_baseline_dxd telescopic_sparsity_sweep \ 34 | --model "$model" \ 35 | --feature_dataset compound_words.pyth.24.-1 \ 36 | --normalize_activations \ 37 | --seed 42 \ 38 | --save_features_together 39 | done -------------------------------------------------------------------------------- /scripts/lr_hparam_tuning.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --mem-per-cpu=4G 3 | #SBATCH -N 1 4 | #SBATCH -c 6 5 | #SBATCH -o log/%j-%a-tune_telescoping_sparsity_sweep.log 6 | #SBATCH -a 1-16 7 | 8 | 9 | # set environment variables 10 | export PATH=$SPARSE_PROBING_ROOT:$PATH 11 | 12 | export HF_DATASETS_OFFLINE=1 13 | export TRANSFORMERS_OFFLINE=1 14 | 15 | export RESULTS_DIR=/home/gridsan/groups/maia_mechint/sparse_probing/results 16 | export FEATURE_DATASET_DIR=/home/gridsan/groups/maia_mechint/sparse_probing/feature_datasets 17 | export TRANSFORMERS_CACHE=/home/gridsan/groups/maia_mechint/models 18 | export HF_DATASETS_CACHE=/home/gridsan/groups/maia_mechint/sparse_probing/hf_home 19 | export HF_HOME=/home/gridsan/groups/maia_mechint/sparse_probing/hf_home 20 | 21 | sleep 0.1 # wait for paths to update 22 | 23 | # activate environment and load modules 24 | source $SPARSE_PROBING_ROOT/sparprob/bin/activate 25 | source /etc/profile 26 | module load gurobi/gurobi-1000 27 | 28 | 29 | python probing_experiment.py \ 30 | --experiment_name lr_hparam_tuning \ 31 | --experiment_type tune_telescoping_sparsity_sweep \ 32 | --model pythia-1b \ 33 | --feature_dataset programming_lang_id.pyth.512.-1 \ 34 | --activation_aggregation mean \ 35 | --normalize_activations 36 | 37 | python probing_experiment.py \ 38 | --experiment_name lr_hparam_tuning \ 39 | --experiment_type tune_telescoping_sparsity_sweep \ 40 | --model pythia-6.9b \ 41 | --feature_dataset distribution_id.pyth.512.-1 \ 42 | --activation_aggregation mean \ 43 | --normalize_activations 44 | 45 | python probing_experiment.py \ 46 | --experiment_name lr_hparam_tuning \ 47 | --experiment_type tune_telescoping_sparsity_sweep \ 48 | --model pythia-410m \ 49 | --feature_dataset text_features.pyth.256.10000 \ 50 | --normalize_activations 51 | 52 | python probing_experiment.py \ 53 | --experiment_name lr_hparam_tuning \ 54 | --experiment_type tune_telescoping_sparsity_sweep \ 55 | --model pythia-160m \ 56 | --feature_dataset ewt.pyth.512.-1 \ 57 | --normalize_activations \ 58 | --feature_subset morph 59 | 60 | 61 | python probing_experiment.py \ 62 | --experiment_name lr_hparam_tuning \ 63 | --experiment_type tune_telescoping_sparsity_sweep \ 64 | --model pythia-1.4b \ 65 | --feature_dataset latex.pyth.1024.-1 \ 66 | --normalize_activations 67 | 68 | 69 | python probing_experiment.py \ 70 | --experiment_name lr_hparam_tuning \ 71 | --experiment_type tune_telescoping_sparsity_sweep \ 72 | --model pythia-2.8b \ 73 | --feature_dataset compound_words.pyth.24.-1 \ 74 | --normalize_activations \ 75 | --feature_subset social-security,trial-court,mental-health -------------------------------------------------------------------------------- /scripts/make_all_feature_datasets.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | python make_feature_datasets.py \ 4 | --feature_collection programming_lang_id 5 | 6 | python make_feature_datasets.py \ 7 | --feature_collection natural_lang_id \ 8 | --ignore_first_k 1 \ 9 | --lang_id_n_tokens 2 10 | 
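# Note: the probing and activation scripts reference the resulting datasets as
# <collection>.pyth.<seq_len>.<n_seqs> (e.g. compound_words.pyth.24.-1,
# text_features.pyth.256.10000); -1 appears to denote the default, unrestricted
# number of sequences.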
11 | python make_feature_datasets.py \ 12 | --feature_collection distribution_id 13 | 14 | python make_feature_datasets.py \ 15 | --feature_collection position \ 16 | --n_seqs 10000 \ 17 | --seq_len 1024 18 | 19 | python make_feature_datasets.py \ 20 | --feature_collection text_features \ 21 | --seq_len 256 \ 22 | --n_seqs 10000 23 | 24 | python make_feature_datasets.py \ 25 | --feature_collection ewt 26 | 27 | python make_feature_datasets.py \ 28 | --feature_collection compound_words \ 29 | --seq_len 24 30 | 31 | python make_feature_datasets.py \ 32 | --feature_collection latex \ 33 | --seq_len 1024 34 | -------------------------------------------------------------------------------- /scripts/osp_hparam_tuning.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --mem-per-cpu=4G 3 | #SBATCH -N 1 4 | #SBATCH -c 24 5 | #SBATCH -o log/%j-%a-osp_tuning.log 6 | #SBATCH -a 1-12 7 | 8 | 9 | # set environment variables 10 | export PATH=$SPARSE_PROBING_ROOT:$PATH 11 | 12 | export HF_DATASETS_OFFLINE=1 13 | export TRANSFORMERS_OFFLINE=1 14 | 15 | export RESULTS_DIR=/home/gridsan/groups/maia_mechint/sparse_probing/results 16 | export FEATURE_DATASET_DIR=/home/gridsan/groups/maia_mechint/sparse_probing/feature_datasets 17 | export TRANSFORMERS_CACHE=/home/gridsan/groups/maia_mechint/models 18 | export HF_DATASETS_CACHE=/home/gridsan/groups/maia_mechint/sparse_probing/hf_home 19 | export HF_HOME=/home/gridsan/groups/maia_mechint/sparse_probing/hf_home 20 | 21 | sleep 0.1 # wait for paths to update 22 | 23 | # activate environment and load modules 24 | source $SPARSE_PROBING_ROOT/sparprob/bin/activate 25 | source /etc/profile 26 | module load gurobi/gurobi-1000 27 | 28 | 29 | python probing_experiment.py \ 30 | --experiment_name osp_hparam_tuning \ 31 | --experiment_type osp_tuning \ 32 | --model pythia-1b \ 33 | --feature_dataset programming_lang_id.pyth.512.-1 \ 34 | --activation_aggregation mean \ 35 | --osp_upto_k 8 \ 36 | --normalize_activations \ 37 | --gurobi_timeout 90 38 | 39 | python probing_experiment.py \ 40 | --experiment_name osp_hparam_tuning \ 41 | --experiment_type osp_tuning \ 42 | --model pythia-6.9b \ 43 | --feature_dataset distribution_id.pyth.512.-1 \ 44 | --activation_aggregation mean \ 45 | --osp_upto_k 8 \ 46 | --normalize_activations \ 47 | --gurobi_timeout 90 48 | 49 | python probing_experiment.py \ 50 | --experiment_name osp_hparam_tuning \ 51 | --experiment_type osp_tuning \ 52 | --model pythia-410m \ 53 | --feature_dataset text_features.pyth.256.10000 \ 54 | --osp_upto_k 8 \ 55 | --normalize_activations \ 56 | --gurobi_timeout 90 57 | 58 | python probing_experiment.py \ 59 | --experiment_name osp_hparam_tuning \ 60 | --experiment_type osp_tuning \ 61 | --model pythia-160m \ 62 | --feature_dataset ewt.pyth.512.-1 \ 63 | --osp_upto_k 8 \ 64 | --normalize_activations \ 65 | --feature_subset morph \ 66 | --gurobi_timeout 60 67 | 68 | 69 | python probing_experiment.py \ 70 | --experiment_name osp_hparam_tuning \ 71 | --experiment_type osp_tuning \ 72 | --model pythia-1.4b \ 73 | --feature_dataset latex.pyth.1024.-1 \ 74 | --osp_upto_k 8 \ 75 | --normalize_activations \ 76 | --gurobi_timeout 90 77 | 78 | 79 | python probing_experiment.py \ 80 | --experiment_name osp_hparam_tuning \ 81 | --experiment_type osp_tuning \ 82 | --model pythia-2.8b \ 83 | --feature_dataset compound_words.pyth.24.-1 \ 84 | --osp_upto_k 8 \ 85 | --normalize_activations \ 86 | --feature_subset social-security,trial-court,mental-health \ 87 | 
--gurobi_timeout 60 -------------------------------------------------------------------------------- /scripts/probe_all_context_features.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --mem-per-cpu=4G 3 | #SBATCH -N 1 4 | #SBATCH -c 24 5 | #SBATCH -o log/all_context_features.log-%j-%a 6 | #SBATCH -a 1-12 7 | 8 | 9 | # set environment variables 10 | export PATH=$SPARSE_PROBING_ROOT:$PATH 11 | 12 | export HF_DATASETS_OFFLINE=1 13 | export TRANSFORMERS_OFFLINE=1 14 | 15 | export RESULTS_DIR=/home/gridsan/groups/maia_mechint/sparse_probing/results 16 | export FEATURE_DATASET_DIR=/home/gridsan/groups/maia_mechint/sparse_probing/feature_datasets 17 | export TRANSFORMERS_CACHE=/home/gridsan/groups/maia_mechint/models 18 | export HF_DATASETS_CACHE=/home/gridsan/groups/maia_mechint/sparse_probing/hf_home 19 | export HF_HOME=/home/gridsan/groups/maia_mechint/sparse_probing/hf_home 20 | 21 | sleep 0.1 # wait for paths to update 22 | 23 | # activate environment and load modules 24 | source $SPARSE_PROBING_ROOT/sparprob/bin/activate 25 | source /etc/profile 26 | module load gurobi/gurobi-1000 27 | 28 | PYTHIA_MODELS=('pythia-70m' 'pythia-160m' 'pythia-410m' 'pythia-1b' 'pythia-1.4b' 'pythia-2.8b' 'pythia-6.9b' 'pythia-12b') 29 | PYTHIA_MODELS_MEDIUM=('pythia-70m' 'pythia-160m' 'pythia-410m' 'pythia-1b' 'pythia-1.4b') 30 | PYTHIA_MODELS_LARGE=('pythia-2.8b' 'pythia-6.9b') 31 | PYTHIA_MODELS_XL=('pythia-12b') 32 | 33 | CONTEXT_FEATURE_DATASETS=('programming_lang_id.pyth.512.-1' 'distribution_id.pyth.512.-1' 'natural_lang_id.pyth.512.-1') 34 | 35 | for model in "${PYTHIA_MODELS_MEDIUM[@]}" 36 | do 37 | python -u probing_experiment.py \ 38 | --experiment_name all_context_features \ 39 | --experiment_type enumerate_monosemantic fast_heuristic_sparsity_sweep \ 40 | --model "$model" \ 41 | --feature_dataset programming_lang_id.pyth.512.-1 \ 42 | --activation_aggregation mean \ 43 | --normalize_activations \ 44 | --seed 42 \ 45 | --save_features_together 46 | 47 | python -u probing_experiment.py \ 48 | --experiment_name all_context_features \ 49 | --experiment_type enumerate_monosemantic fast_heuristic_sparsity_sweep \ 50 | --model "$model" \ 51 | --feature_dataset distribution_id.pyth.512.-1 \ 52 | --activation_aggregation mean \ 53 | --normalize_activations \ 54 | --seed 42 \ 55 | --save_features_together 56 | 57 | python -u probing_experiment.py \ 58 | --experiment_name all_context_features \ 59 | --experiment_type enumerate_monosemantic fast_heuristic_sparsity_sweep \ 60 | --model "$model" \ 61 | --feature_dataset natural_lang_id.pyth.512.-1 \ 62 | --activation_aggregation mean \ 63 | --normalize_activations \ 64 | --seed 42 \ 65 | --save_features_together 66 | done -------------------------------------------------------------------------------- /scripts/probing_dataset_activations/make_compound_word_features_apd.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH -o log/%j-apd_compound_words.log 3 | #SBATCH -c 20 4 | #SBATCH --gres=gpu:volta:1 5 | 6 | # set environment variables 7 | export PATH=$SPARSE_PROBING_ROOT:$PATH 8 | 9 | export HF_DATASETS_OFFLINE=1 10 | export TRANSFORMERS_OFFLINE=1 11 | 12 | export RESULTS_DIR=/home/gridsan/groups/maia_mechint/sparse_probing/results 13 | export FEATURE_DATASET_DIR=/home/gridsan/groups/maia_mechint/sparse_probing/feature_datasets 14 | export TRANSFORMERS_CACHE=/home/gridsan/groups/maia_mechint/models 15 | export 
HF_DATASETS_CACHE=/home/gridsan/groups/maia_mechint/sparse_probing/hf_home 16 | export HF_HOME=/home/gridsan/groups/maia_mechint/sparse_probing/hf_home 17 | 18 | sleep 0.1 # wait for paths to update 19 | 20 | # activate environment and load modules 21 | source $SPARSE_PROBING_ROOT/sparprob/bin/activate 22 | 23 | PYTHIA_MODELS=('pythia-70m' 'pythia-160m' 'pythia-410m' 'pythia-1b' 'pythia-1.4b' 'pythia-2.8b' 'pythia-6.9b') 24 | 25 | for model in "${PYTHIA_MODELS[@]}" 26 | do 27 | python -u get_activations.py \ 28 | --experiment_name compound_apd \ 29 | --experiment_type activation_probe_dataset \ 30 | --model "$model" \ 31 | --feature_dataset compound_words.pyth.24.-1 \ 32 | --seed 42 33 | done -------------------------------------------------------------------------------- /scripts/probing_dataset_activations/make_context_features_apd.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH -o log/%j-all_context_features.log 3 | #SBATCH -c 20 4 | #SBATCH --gres=gpu:volta:1 5 | 6 | # set environment variables 7 | export PATH=$SPARSE_PROBING_ROOT:$PATH 8 | 9 | export HF_DATASETS_OFFLINE=1 10 | export TRANSFORMERS_OFFLINE=1 11 | 12 | export RESULTS_DIR=/home/gridsan/groups/maia_mechint/sparse_probing/results 13 | export FEATURE_DATASET_DIR=/home/gridsan/groups/maia_mechint/sparse_probing/feature_datasets 14 | export TRANSFORMERS_CACHE=/home/gridsan/groups/maia_mechint/models 15 | export HF_DATASETS_CACHE=/home/gridsan/groups/maia_mechint/sparse_probing/hf_home 16 | export HF_HOME=/home/gridsan/groups/maia_mechint/sparse_probing/hf_home 17 | 18 | sleep 0.1 # wait for paths to update 19 | 20 | # activate environment and load modules 21 | source $SPARSE_PROBING_ROOT/sparprob/bin/activate 22 | 23 | PYTHIA_MODELS=('pythia-70m' 'pythia-160m' 'pythia-410m' 'pythia-1b' 'pythia-1.4b' 'pythia-2.8b' 'pythia-6.9b') 24 | 25 | for model in "${PYTHIA_MODELS[@]}" 26 | do 27 | python -u get_activations.py \ 28 | --experiment_name all_context_features_test \ 29 | --experiment_type activation_probe_dataset \ 30 | --model "$model" \ 31 | --feature_dataset programming_lang_id.pyth.512.-1 \ 32 | --activation_aggregation mean \ 33 | --seed 42 34 | 35 | python -u get_activations.py \ 36 | --experiment_name all_context_features_test \ 37 | --experiment_type activation_probe_dataset \ 38 | --model "$model" \ 39 | --feature_dataset distribution_id.pyth.512.-1 \ 40 | --activation_aggregation mean \ 41 | --seed 42 42 | 43 | python -u get_activations.py \ 44 | --experiment_name all_context_features_test \ 45 | --experiment_type activation_probe_dataset \ 46 | --model "$model" \ 47 | --feature_dataset natural_lang_id.pyth.512.-1 \ 48 | --activation_aggregation mean \ 49 | --seed 42 50 | 51 | python -u get_activations.py \ 52 | --experiment_name all_context_features_test \ 53 | --experiment_type activation_probe_dataset \ 54 | --model "$model" \ 55 | --feature_dataset natural_lang_id.pyth.512.-1 \ 56 | --seed 42 57 | done -------------------------------------------------------------------------------- /scripts/probing_dataset_activations/make_ewt_feature_apd.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH -o log/%j-apd_ewt.log 3 | #SBATCH -c 20 4 | #SBATCH --gres=gpu:volta:1 5 | 6 | # set environment variables 7 | export PATH=$SPARSE_PROBING_ROOT:$PATH 8 | 9 | export HF_DATASETS_OFFLINE=1 10 | export TRANSFORMERS_OFFLINE=1 11 | 12 | export 
RESULTS_DIR=/home/gridsan/groups/maia_mechint/sparse_probing/results 13 | export FEATURE_DATASET_DIR=/home/gridsan/groups/maia_mechint/sparse_probing/feature_datasets 14 | export TRANSFORMERS_CACHE=/home/gridsan/groups/maia_mechint/models 15 | export HF_DATASETS_CACHE=/home/gridsan/groups/maia_mechint/sparse_probing/hf_home 16 | export HF_HOME=/home/gridsan/groups/maia_mechint/sparse_probing/hf_home 17 | 18 | sleep 0.1 # wait for paths to update 19 | 20 | # activate environment and load modules 21 | source $SPARSE_PROBING_ROOT/sparprob/bin/activate 22 | 23 | PYTHIA_MODELS=('pythia-70m' 'pythia-160m' 'pythia-410m' 'pythia-1b' 'pythia-1.4b' 'pythia-2.8b' 'pythia-6.9b') 24 | 25 | for model in "${PYTHIA_MODELS[@]}" 26 | do 27 | python -u get_activations.py \ 28 | --experiment_name ewt_apd \ 29 | --experiment_type activation_probe_dataset \ 30 | --model "$model" \ 31 | --feature_dataset ewt.pyth.512.-1 \ 32 | --seed 42 33 | done -------------------------------------------------------------------------------- /scripts/probing_dataset_activations/make_latex_feature_apd.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH -o log/%j-apd_latex.log 3 | #SBATCH -c 20 4 | #SBATCH --gres=gpu:volta:1 5 | 6 | # set environment variables 7 | export PATH=$SPARSE_PROBING_ROOT:$PATH 8 | 9 | export HF_DATASETS_OFFLINE=1 10 | export TRANSFORMERS_OFFLINE=1 11 | 12 | export RESULTS_DIR=/home/gridsan/groups/maia_mechint/sparse_probing/results 13 | export FEATURE_DATASET_DIR=/home/gridsan/groups/maia_mechint/sparse_probing/feature_datasets 14 | export TRANSFORMERS_CACHE=/home/gridsan/groups/maia_mechint/models 15 | export HF_DATASETS_CACHE=/home/gridsan/groups/maia_mechint/sparse_probing/hf_home 16 | export HF_HOME=/home/gridsan/groups/maia_mechint/sparse_probing/hf_home 17 | 18 | sleep 0.1 # wait for paths to update 19 | 20 | # activate environment and load modules 21 | source $SPARSE_PROBING_ROOT/sparprob/bin/activate 22 | 23 | PYTHIA_MODELS=('pythia-70m' 'pythia-160m' 'pythia-410m' 'pythia-1b' 'pythia-1.4b' 'pythia-2.8b' 'pythia-6.9b') 24 | 25 | for model in "${PYTHIA_MODELS[@]}" 26 | do 27 | python -u get_activations.py \ 28 | --experiment_name latex_apd \ 29 | --experiment_type activation_probe_dataset \ 30 | --model "$model" \ 31 | --feature_dataset latex.pyth.1024.-1 \ 32 | --batch_size 8 \ 33 | --seed 42 34 | done -------------------------------------------------------------------------------- /scripts/probing_dataset_activations/make_position_feature_apd.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH -o log/%j-apd_position.log 3 | #SBATCH -c 20 4 | #SBATCH --gres=gpu:volta:1 5 | 6 | # set environment variables 7 | export PATH=$SPARSE_PROBING_ROOT:$PATH 8 | 9 | export HF_DATASETS_OFFLINE=1 10 | export TRANSFORMERS_OFFLINE=1 11 | 12 | export RESULTS_DIR=/home/gridsan/groups/maia_mechint/sparse_probing/results 13 | export FEATURE_DATASET_DIR=/home/gridsan/groups/maia_mechint/sparse_probing/feature_datasets 14 | export TRANSFORMERS_CACHE=/home/gridsan/groups/maia_mechint/models 15 | export HF_DATASETS_CACHE=/home/gridsan/groups/maia_mechint/sparse_probing/hf_home 16 | export HF_HOME=/home/gridsan/groups/maia_mechint/sparse_probing/hf_home 17 | 18 | sleep 0.1 # wait for paths to update 19 | 20 | # activate environment and load modules 21 | source $SPARSE_PROBING_ROOT/sparprob/bin/activate 22 | 23 | PYTHIA_MODELS=('pythia-70m' 'pythia-160m' 'pythia-410m' 'pythia-1b' 'pythia-1.4b' 
'pythia-2.8b' 'pythia-6.9b') 24 | 25 | for model in "${PYTHIA_MODELS[@]}" 26 | do 27 | python -u get_activations.py \ 28 | --experiment_name position_apd \ 29 | --experiment_type activation_probe_dataset \ 30 | --model "$model" \ 31 | --feature_dataset position.pyth.1024.10000 \ 32 | --seed 42 33 | done -------------------------------------------------------------------------------- /scripts/probing_dataset_activations/make_text_features_apd.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH -o log/%j-apd_text.log 3 | #SBATCH -c 20 4 | #SBATCH --gres=gpu:volta:1 5 | 6 | # set environment variables 7 | export PATH=$SPARSE_PROBING_ROOT:$PATH 8 | 9 | export HF_DATASETS_OFFLINE=1 10 | export TRANSFORMERS_OFFLINE=1 11 | 12 | export RESULTS_DIR=/home/gridsan/groups/maia_mechint/sparse_probing/results 13 | export FEATURE_DATASET_DIR=/home/gridsan/groups/maia_mechint/sparse_probing/feature_datasets 14 | export TRANSFORMERS_CACHE=/home/gridsan/groups/maia_mechint/models 15 | export HF_DATASETS_CACHE=/home/gridsan/groups/maia_mechint/sparse_probing/hf_home 16 | export HF_HOME=/home/gridsan/groups/maia_mechint/sparse_probing/hf_home 17 | 18 | sleep 0.1 # wait for paths to update 19 | 20 | # activate environment and load modules 21 | source $SPARSE_PROBING_ROOT/sparprob/bin/activate 22 | 23 | PYTHIA_MODELS=('pythia-70m' 'pythia-160m' 'pythia-410m' 'pythia-1b' 'pythia-1.4b' 'pythia-2.8b' 'pythia-6.9b') 24 | 25 | for model in "${PYTHIA_MODELS[@]}" 26 | do 27 | python -u get_activations.py \ 28 | --experiment_name text_apd \ 29 | --experiment_type activation_probe_dataset \ 30 | --model "$model" \ 31 | --feature_dataset text_features.pyth.256.10000 \ 32 | --seed 42 33 | done -------------------------------------------------------------------------------- /scripts/probing_dataset_activations/make_wikidata_apd.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH -c 20 3 | #SBATCH --gres=gpu:volta:1 4 | #SBATCH -o log/%j-wikidata-apd 5 | 6 | # set environment variables 7 | export PATH=$SPARSE_PROBING_ROOT:$PATH 8 | 9 | export HF_DATASETS_OFFLINE=1 10 | export TRANSFORMERS_OFFLINE=1 11 | 12 | export RESULTS_DIR=/home/gridsan/groups/maia_mechint/sparse_probing/results 13 | export FEATURE_DATASET_DIR=/home/gridsan/groups/maia_mechint/sparse_probing/feature_datasets 14 | export TRANSFORMERS_CACHE=/home/gridsan/groups/maia_mechint/models 15 | export HF_DATASETS_CACHE=/home/gridsan/groups/maia_mechint/sparse_probing/hf_home 16 | export HF_HOME=/home/gridsan/groups/maia_mechint/sparse_probing/hf_home 17 | 18 | sleep 0.1 # wait for paths to update 19 | 20 | # activate environment and load modules 21 | source $SPARSE_PROBING_ROOT/sparprob/bin/activate 22 | source /etc/profile 23 | 24 | PYTHIA_MODELS=('pythia-70m' 'pythia-160m' 'pythia-410m' 'pythia-1b' 'pythia-1.4b' 'pythia-2.8b' 'pythia-6.9b') 25 | # 'pythia-70m' 'pythia-160m' 'pythia-410m' 'pythia-1b' 'pythia-1.4b') 26 | #n_layers: 6 12 24 16 24 32 32 27 | 28 | for model in "${PYTHIA_MODELS[@]}" 29 | do 30 | python -u get_activations.py \ 31 | --experiment_name wikidata_apd \ 32 | --experiment_type activation_probe_dataset \ 33 | --feature_dataset wikidata_sorted_is_alive.pyth.128.6000 \ 34 | --model $model \ 35 | --activation_aggregation max 36 | 37 | python -u get_activations.py \ 38 | --experiment_name wikidata_apd \ 39 | --experiment_type activation_probe_dataset \ 40 | --feature_dataset wikidata_sorted_occupation.pyth.128.6000 \ 41 | 
--model $model \ 42 | --activation_aggregation max 43 | 44 | python -u get_activations.py \ 45 | --experiment_name wikidata_apd \ 46 | --experiment_type activation_probe_dataset \ 47 | --feature_dataset wikidata_sorted_occupation_athlete.pyth.128.5000 \ 48 | --model $model \ 49 | --activation_aggregation max 50 | 51 | python -u get_activations.py \ 52 | --experiment_name wikidata_apd \ 53 | --experiment_type activation_probe_dataset \ 54 | --feature_dataset wikidata_sorted_political_party.pyth.128.3000 \ 55 | --model $model \ 56 | --activation_aggregation max 57 | 58 | python -u get_activations.py \ 59 | --experiment_name wikidata_apd \ 60 | --experiment_type activation_probe_dataset \ 61 | --feature_dataset wikidata_sorted_sex_or_gender.pyth.128.6000 \ 62 | --model $model \ 63 | --activation_aggregation max 64 | done 65 | -------------------------------------------------------------------------------- /scripts/run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --mem-per-cpu=4G 3 | #SBATCH -N 1 4 | #SBATCH -c 24 5 | #SBATCH -o log/runtest.log-%j-%a 6 | #SBATCH -a 1-16 7 | 8 | 9 | # set environment variables 10 | export PATH=$SPARSE_PROBING_ROOT:$PATH 11 | 12 | export HF_DATASETS_OFFLINE=1 13 | export TRANSFORMERS_OFFLINE=1 14 | 15 | export RESULTS_DIR=/home/gridsan/groups/maia_mechint/sparse_probing/results 16 | export FEATURE_DATASET_DIR=/home/gridsan/groups/maia_mechint/sparse_probing/feature_datasets 17 | export TRANSFORMERS_CACHE=/home/gridsan/groups/maia_mechint/models 18 | export HF_DATASETS_CACHE=/home/gridsan/groups/maia_mechint/sparse_probing/hf_home 19 | export HF_HOME=/home/gridsan/groups/maia_mechint/sparse_probing/hf_home 20 | 21 | sleep 0.1 # wait for paths to update 22 | 23 | # activate environment and load modules 24 | source $SPARSE_PROBING_ROOT/sparprob/bin/activate 25 | source /etc/profile 26 | module load gurobi/gurobi-951 27 | 28 | 29 | PYTHIA_MODELS=('pythia-19m' 'pythia-125m' 'pythia-350m' 'pythia-800m' 'pythia-1.3b') 30 | 31 | # Text features sweep across all models 32 | for model in "${PYTHIA_MODELS[@]}" 33 | do 34 | python probing_experiment.py \ 35 | --experiment_name code_lang_max_test \ 36 | --experiment_type heuristic_sparsity_sweep\ 37 | --model "$model" \ 38 | --feature_dataset programming_lang_id.pyth.512.-1 \ 39 | --activation_aggregation max 40 | done 41 | 42 | 43 | for model in "${PYTHIA_MODELS[@]}" 44 | do 45 | python probing_experiment.py \ 46 | --experiment_name nat_lang_test \ 47 | --experiment_type heuristic_sparsity_sweep\ 48 | --model "$model" \ 49 | --feature_dataset natural_lang_id.pyth.512.-1 50 | done 51 | 52 | 53 | 54 | # python probing_experiment.py \ 55 | # --experiment_name language_id_test \ 56 | # --experiment_type heuristic_sparsity_sweep\ 57 | # --model pythia-19m \ 58 | # --feature_dataset github_lang_id.test.pyth.512.-1.True 59 | # --average_seq_activations 60 | 61 | 62 | # python probing_experiment.py \ 63 | # --experiment_name heuristic_feature_selection_test \ 64 | # --experiment_type test_heuristic_filtering \ 65 | # --feature_datasets true_binary_token_supervised_feature_datasets \ 66 | # --n_seqs 1000 \ 67 | # --osp_upto_k 12 \ 68 | # --gurobi_timeout 600 69 | -------------------------------------------------------------------------------- /scripts/run_code_lang_id_experiment.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --mem-per-cpu=4G 3 | #SBATCH -N 1 4 | #SBATCH -c 24 5 | #SBATCH -o 
log/runtest.log-%j-%a 6 | #SBATCH -a 1-16 7 | 8 | 9 | # set environment variables 10 | export PATH=$SPARSE_PROBING_ROOT:$PATH 11 | 12 | export HF_DATASETS_OFFLINE=1 13 | export TRANSFORMERS_OFFLINE=1 14 | 15 | export RESULTS_DIR=/home/gridsan/groups/maia_mechint/sparse_probing/results 16 | export FEATURE_DATASET_DIR=/home/gridsan/groups/maia_mechint/sparse_probing/feature_datasets 17 | export TRANSFORMERS_CACHE=/home/gridsan/groups/maia_mechint/models 18 | export HF_DATASETS_CACHE=/home/gridsan/groups/maia_mechint/sparse_probing/hf_home 19 | export HF_HOME=/home/gridsan/groups/maia_mechint/sparse_probing/hf_home 20 | 21 | sleep 0.1 # wait for paths to update 22 | 23 | # activate environment and load modules 24 | source $SPARSE_PROBING_ROOT/sparprob/bin/activate 25 | source /etc/profile 26 | module load gurobi/gurobi-951 27 | 28 | 29 | PYTHIA_MODELS=('pythia-19m' 'pythia-125m' 'pythia-350m' 'pythia-800m' 'pythia-1.3b') 30 | 31 | python probing_experiment.py \ 32 | --experiment_name code_lang_max_test \ 33 | --experiment_type heuristic_sparsity_sweep\ 34 | --model pythia-800m \ 35 | --feature_dataset programming_lang_id.pyth.512.-1 \ 36 | --activation_aggregation max 37 | 38 | python probing_experiment.py \ 39 | --experiment_name code_lang_test \ 40 | --experiment_type heuristic_sparsity_sweep\ 41 | --model pythia-800m \ 42 | --feature_dataset programming_lang_id.pyth.512.-1 \ 43 | --probe_location hook_resid_post 44 | 45 | python probing_experiment.py \ 46 | --experiment_name code_lang_test \ 47 | --experiment_type heuristic_sparsity_sweep\ 48 | --model pythia-800m \ 49 | --feature_dataset programming_lang_id.pyth.512.-1 50 | 51 | 52 | -------------------------------------------------------------------------------- /scripts/run_compound_words.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --mem-per-cpu=4G 3 | #SBATCH -N 1 4 | #SBATCH -c 16 5 | #SBATCH -o log/compound_words.log-%j-%a 6 | #SBATCH -a 1-12 7 | 8 | 9 | # set environment variables 10 | export PATH=$SPARSE_PROBING_ROOT:$PATH 11 | 12 | export HF_DATASETS_OFFLINE=1 13 | export TRANSFORMERS_OFFLINE=1 14 | 15 | export RESULTS_DIR=/home/gridsan/groups/maia_mechint/sparse_probing/results 16 | export FEATURE_DATASET_DIR=/home/gridsan/groups/maia_mechint/sparse_probing/feature_datasets 17 | export TRANSFORMERS_CACHE=/home/gridsan/groups/maia_mechint/models 18 | export HF_DATASETS_CACHE=/home/gridsan/groups/maia_mechint/sparse_probing/hf_home 19 | export HF_HOME=/home/gridsan/groups/maia_mechint/sparse_probing/hf_home 20 | 21 | sleep 0.1 # wait for paths to update 22 | 23 | # activate environment and load modules 24 | source $SPARSE_PROBING_ROOT/sparprob/bin/activate 25 | source /etc/profile 26 | module load gurobi/gurobi-1000 27 | 28 | 29 | PYTHIA_MODELS=('pythia-70m' 'pythia-160m' 'pythia-410m' 'pythia-1b' 'pythia-1.4b') 30 | 31 | 32 | for model in "${PYTHIA_MODELS[@]}" 33 | do 34 | python probing_experiment.py \ 35 | --experiment_name bigram_next_token_test \ 36 | --experiment_type enumerate_monosemantic fast_heuristic_sparsity_sweep \ 37 | --model "$model" \ 38 | --feature_dataset compound_words.pyth.64.-1 \ 39 | --batch_size 128 \ 40 | --probe_next_token_feature 41 | done -------------------------------------------------------------------------------- /scripts/run_distribution_id.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --mem-per-cpu=4G 3 | #SBATCH -N 1 4 | #SBATCH -c 24 5 | #SBATCH -o 
log/runtest.log-%j-%a 6 | #SBATCH -a 1-8 7 | 8 | 9 | # set environment variables 10 | export PATH=$SPARSE_PROBING_ROOT:$PATH 11 | 12 | export HF_DATASETS_OFFLINE=1 13 | export TRANSFORMERS_OFFLINE=1 14 | 15 | export RESULTS_DIR=/home/gridsan/groups/maia_mechint/sparse_probing/results 16 | export FEATURE_DATASET_DIR=/home/gridsan/groups/maia_mechint/sparse_probing/feature_datasets 17 | export TRANSFORMERS_CACHE=/home/gridsan/groups/maia_mechint/models 18 | export HF_DATASETS_CACHE=/home/gridsan/groups/maia_mechint/sparse_probing/hf_home 19 | export HF_HOME=/home/gridsan/groups/maia_mechint/sparse_probing/hf_home 20 | 21 | sleep 0.1 # wait for paths to update 22 | 23 | # activate environment and load modules 24 | source $SPARSE_PROBING_ROOT/sparprob/bin/activate 25 | source /etc/profile 26 | module load gurobi/gurobi-1000 27 | 28 | PYTHIA_MODELS=('pythia-800m' 'pythia-1.3b') 29 | 30 | for model in "${PYTHIA_MODELS[@]}" 31 | do 32 | python probing_experiment.py \ 33 | --experiment_name distribution_id_test \ 34 | --experiment_type enumerate_monosemantic heuristic_sparsity_sweep optimal_sparse_probing \ 35 | --model "$model" \ 36 | --feature_dataset distribution_id.pyth.512.-1 \ 37 | --activation_aggregation mean 38 | done -------------------------------------------------------------------------------- /scripts/run_ewt_fast_probe.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --mem-per-cpu=4G 3 | #SBATCH -N 1 4 | #SBATCH -c 24 5 | #SBATCH -o log/ewt_fast_probe.log-%j-%a 6 | #SBATCH -a 1-16 7 | 8 | 9 | # set environment variables 10 | export PATH=$SPARSE_PROBING_ROOT:$PATH 11 | 12 | export HF_DATASETS_OFFLINE=1 13 | export TRANSFORMERS_OFFLINE=1 14 | 15 | export RESULTS_DIR=/home/gridsan/groups/maia_mechint/sparse_probing/results 16 | export FEATURE_DATASET_DIR=/home/gridsan/groups/maia_mechint/sparse_probing/feature_datasets 17 | export TRANSFORMERS_CACHE=/home/gridsan/groups/maia_mechint/models 18 | export HF_DATASETS_CACHE=/home/gridsan/groups/maia_mechint/sparse_probing/hf_home 19 | export HF_HOME=/home/gridsan/groups/maia_mechint/sparse_probing/hf_home 20 | 21 | sleep 0.1 # wait for paths to update 22 | 23 | # activate environment and load modules 24 | source $SPARSE_PROBING_ROOT/sparprob/bin/activate 25 | source /etc/profile 26 | module load gurobi/gurobi-1000 27 | 28 | 29 | PYTHIA_MODELS=('pythia-70m' 'pythia-160m' 'pythia-410m' 'pythia-1b' 'pythia-1.4b') 30 | 31 | for model in "${PYTHIA_MODELS[@]}" 32 | do 33 | python -u probing_experiment.py \ 34 | --experiment_name ewt_test \ 35 | --experiment_type fast_heuristic_sparsity_sweep enumerate_monosemantic \ 36 | --model "$model" \ 37 | --feature_dataset ewt.pyth.512.-1 \ 38 | --save_features_together 39 | done 40 | 41 | for model in "${PYTHIA_MODELS[@]}" 42 | do 43 | python -u probing_experiment.py \ 44 | --experiment_name ewt_test \ 45 | --experiment_type fast_heuristic_sparsity_sweep enumerate_monosemantic \ 46 | --model "$model" \ 47 | --feature_dataset ewt.pyth.512.-1 \ 48 | --save_features_together \ 49 | --probe_next_token_feature 50 | done 51 | -------------------------------------------------------------------------------- /scripts/run_feature_selection.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --mem-per-cpu=4G 3 | #SBATCH -N 1 4 | #SBATCH -c 24 5 | #SBATCH -o log/feature_selection.log-%j-%a 6 | #SBATCH -a 1-16 7 | 8 | 9 | # set environment variables 10 | export PATH=$SPARSE_PROBING_ROOT:$PATH 11 
| 12 | export HF_DATASETS_OFFLINE=1 13 | export TRANSFORMERS_OFFLINE=1 14 | 15 | export RESULTS_DIR=/home/gridsan/groups/maia_mechint/sparse_probing/results 16 | export FEATURE_DATASET_DIR=/home/gridsan/groups/maia_mechint/sparse_probing/feature_datasets 17 | export TRANSFORMERS_CACHE=/home/gridsan/groups/maia_mechint/models 18 | export HF_DATASETS_CACHE=/home/gridsan/groups/maia_mechint/sparse_probing/hf_home 19 | export HF_HOME=/home/gridsan/groups/maia_mechint/sparse_probing/hf_home 20 | 21 | sleep 0.1 # wait for paths to update 22 | 23 | # activate environment and load modules 24 | source $SPARSE_PROBING_ROOT/sparprob/bin/activate 25 | source /etc/profile 26 | module load gurobi/gurobi-1000 27 | 28 | 29 | PYTHIA_MODELS=('pythia-70m' 'pythia-1b') 30 | 31 | for model in "${PYTHIA_MODELS[@]}" 32 | do 33 | python -u probing_experiment.py \ 34 | --experiment_name feature_selection_norm_test \ 35 | --experiment_type compare_feature_selection \ 36 | --model "$model" \ 37 | --feature_dataset ewt.pyth.512.-1 \ 38 | --osp_upto_k 8 \ 39 | --gurobi_timeout 60 \ 40 | --save_features_together \ 41 | --normalize_activations 42 | 43 | python -u probing_experiment.py \ 44 | --experiment_name feature_selection_test \ 45 | --experiment_type compare_feature_selection \ 46 | --model "$model" \ 47 | --feature_dataset ewt.pyth.512.-1 \ 48 | --osp_upto_k 8 \ 49 | --gurobi_timeout 60 \ 50 | --save_features_together 51 | 52 | python -u probing_experiment.py \ 53 | --experiment_name feature_selection_norm_test \ 54 | --experiment_type compare_feature_selection \ 55 | --model "$model" \ 56 | --feature_dataset programming_lang_id.pyth.512.-1 \ 57 | --activation_aggregation mean \ 58 | --osp_upto_k 8 \ 59 | --gurobi_timeout 60 \ 60 | --save_features_together \ 61 | --normalize_activations 62 | 63 | python -u probing_experiment.py \ 64 | --experiment_name feature_selection_test \ 65 | --experiment_type compare_feature_selection \ 66 | --model "$model" \ 67 | --feature_dataset programming_lang_id.pyth.512.-1 \ 68 | --activation_aggregation mean \ 69 | --osp_upto_k 8 \ 70 | --gurobi_timeout 60 \ 71 | --save_features_together 72 | done 73 | -------------------------------------------------------------------------------- /scripts/run_iterative_pruning.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --mem-per-cpu=4G 3 | #SBATCH -N 1 4 | #SBATCH -c 24 5 | #SBATCH -o log/runtest.log-%j-%a 6 | #SBATCH -a 1-24 7 | 8 | 9 | # set environment variables 10 | export PATH=$SPARSE_PROBING_ROOT:$PATH 11 | 12 | export HF_DATASETS_OFFLINE=1 13 | export TRANSFORMERS_OFFLINE=1 14 | 15 | export RESULTS_DIR=/home/gridsan/groups/maia_mechint/sparse_probing/results 16 | export FEATURE_DATASET_DIR=/home/gridsan/groups/maia_mechint/sparse_probing/feature_datasets 17 | export TRANSFORMERS_CACHE=/home/gridsan/groups/maia_mechint/models 18 | export HF_DATASETS_CACHE=/home/gridsan/groups/maia_mechint/sparse_probing/hf_home 19 | export HF_HOME=/home/gridsan/groups/maia_mechint/sparse_probing/hf_home 20 | 21 | sleep 0.1 # wait for paths to update 22 | 23 | # activate environment and load modules 24 | source $SPARSE_PROBING_ROOT/sparprob/bin/activate 25 | source /etc/profile 26 | module load gurobi/gurobi-1000 27 | 28 | 29 | PYTHIA_MODELS=('pythia-70m' 'pythia-160m' 'pythia-410m' 'pythia-1b' 'pythia-1.4b') 30 | 31 | 32 | for model in "${PYTHIA_MODELS[@]}" 33 | do 34 | python probing_experiment.py \ 35 | --experiment_name bigram_test \ 36 | --experiment_type osp_iterative_pruning \ 37 | 
--model "$model" \ 38 | --feature_dataset compound_words.pyth.16.-1 \ 39 | --batch_size 256 \ 40 | --gurobi_timeout 150 41 | done -------------------------------------------------------------------------------- /scripts/run_position_probe.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --mem-per-cpu=4G 3 | #SBATCH -N 1 4 | #SBATCH -c 24 5 | #SBATCH -o log/position.log-%j-%a 6 | #SBATCH -a 1-12 7 | 8 | 9 | # set environment variables 10 | export PATH=$SPARSE_PROBING_ROOT:$PATH 11 | 12 | export HF_DATASETS_OFFLINE=1 13 | export TRANSFORMERS_OFFLINE=1 14 | 15 | export RESULTS_DIR=/home/gridsan/groups/maia_mechint/sparse_probing/results 16 | export FEATURE_DATASET_DIR=/home/gridsan/groups/maia_mechint/sparse_probing/feature_datasets 17 | export TRANSFORMERS_CACHE=/home/gridsan/groups/maia_mechint/models 18 | export HF_DATASETS_CACHE=/home/gridsan/groups/maia_mechint/sparse_probing/hf_home 19 | export HF_HOME=/home/gridsan/groups/maia_mechint/sparse_probing/hf_home 20 | 21 | sleep 0.1 # wait for paths to update 22 | 23 | # activate environment and load modules 24 | source $SPARSE_PROBING_ROOT/sparprob/bin/activate 25 | source /etc/profile 26 | module load gurobi/gurobi-1000 27 | 28 | PYTHIA_MODELS=('pythia-70m' 'pythia-160m' 'pythia-410m' 'pythia-1b' 'pythia-1.4b') 29 | 30 | for model in "${PYTHIA_MODELS[@]}" 31 | do 32 | python -u probing_experiment.py \ 33 | --experiment_name position_test \ 34 | --experiment_type heuristic_sparse_regression_sweep \ 35 | --model "$model" \ 36 | --feature_dataset position.pyth.1024.10000 \ 37 | --save_features_together 38 | 39 | python -u probing_experiment.py \ 40 | --experiment_name position_test \ 41 | --experiment_type dense_regression_probe \ 42 | --model "$model" \ 43 | --feature_dataset position.pyth.1024.10000 \ 44 | --probe_location hook_resid_post \ 45 | --save_features_together 46 | done -------------------------------------------------------------------------------- /scripts/run_probe_refactor_test.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --mem-per-cpu=4G 3 | #SBATCH -N 1 4 | #SBATCH -c 6 5 | #SBATCH -o log/refactor_probe_test.log-%j-%a 6 | #SBATCH -a 1-16 7 | 8 | 9 | # set environment variables 10 | export PATH=$SPARSE_PROBING_ROOT:$PATH 11 | 12 | export HF_DATASETS_OFFLINE=1 13 | export TRANSFORMERS_OFFLINE=1 14 | 15 | export RESULTS_DIR=/home/gridsan/groups/maia_mechint/sparse_probing/results 16 | export FEATURE_DATASET_DIR=/home/gridsan/groups/maia_mechint/sparse_probing/feature_datasets 17 | export TRANSFORMERS_CACHE=/home/gridsan/groups/maia_mechint/models 18 | export HF_DATASETS_CACHE=/home/gridsan/groups/maia_mechint/sparse_probing/hf_home 19 | export HF_HOME=/home/gridsan/groups/maia_mechint/sparse_probing/hf_home 20 | 21 | sleep 0.1 # wait for paths to update 22 | 23 | # activate environment and load modules 24 | source $SPARSE_PROBING_ROOT/sparprob/bin/activate 25 | source /etc/profile 26 | module load gurobi/gurobi-1000 27 | 28 | PYTHIA_MODELS=('pythia-70m' 'pythia-160m' 'pythia-410m' 'pythia-1b' 'pythia-1.4b' 'pythia-2.8b' 'pythia-6.9b') 29 | 30 | for model in "${PYTHIA_MODELS[@]}" 31 | do 32 | python -u probing_experiment.py \ 33 | --experiment_name refactor_test \ 34 | --experiment_type fast_heuristic_sparsity_sweep \ 35 | --model "$model" \ 36 | --feature_dataset programming_lang_id.pyth.512.-1 \ 37 | --activation_aggregation mean \ 38 | --normalize_activations \ 39 | --seed 42 \ 40 | 
--save_features_together 41 | 42 | 43 | python -u probing_experiment.py \ 44 | --experiment_name refactor_test \ 45 | --experiment_type fast_heuristic_sparsity_sweep \ 46 | --model "$model" \ 47 | --feature_dataset distribution_id.pyth.512.-1 \ 48 | --activation_aggregation mean \ 49 | --normalize_activations \ 50 | --seed 42 \ 51 | --save_features_together 52 | 53 | 54 | python -u probing_experiment.py \ 55 | --experiment_name refactor_test \ 56 | --experiment_type fast_heuristic_sparsity_sweep \ 57 | --model "$model" \ 58 | --feature_dataset natural_lang_id.pyth.512.-1 \ 59 | --activation_aggregation mean \ 60 | --normalize_activations \ 61 | --seed 42 \ 62 | --save_features_together 63 | 64 | 65 | python -u probing_experiment.py \ 66 | --experiment_name refactor_test \ 67 | --experiment_type fast_heuristic_sparsity_sweep \ 68 | --model "$model" \ 69 | --feature_dataset compound_words.pyth.24.-1 \ 70 | --normalize_activations \ 71 | --seed 42 \ 72 | --save_features_together 73 | 74 | 75 | python -u probing_experiment.py \ 76 | --experiment_name refactor_test \ 77 | --experiment_type fast_heuristic_sparsity_sweep \ 78 | --model "$model" \ 79 | --feature_dataset text_features.pyth.256.10000 \ 80 | --normalize_activations \ 81 | --seed 42 \ 82 | --save_features_together 83 | 84 | 85 | python -u probing_experiment.py \ 86 | --experiment_name refactor_test \ 87 | --experiment_type heuristic_sparse_regression_sweep \ 88 | --model "$model" \ 89 | --feature_dataset position.pyth.1024.10000 \ 90 | --normalize_activations \ 91 | --seed 42 \ 92 | --save_features_together 93 | 94 | 95 | python -u probing_experiment.py \ 96 | --experiment_name refactor_test \ 97 | --experiment_type fast_heuristic_sparsity_sweep \ 98 | --model "$model" \ 99 | --feature_dataset ewt.pyth.512.-1 \ 100 | --normalize_activations \ 101 | --seed 42 \ 102 | --save_features_together 103 | 104 | 105 | python -u probing_experiment.py \ 106 | --experiment_name refactor_test \ 107 | --experiment_type fast_heuristic_sparsity_sweep \ 108 | --model "$model" \ 109 | --feature_dataset latex.pyth.1024.-1 \ 110 | --normalize_activations \ 111 | --seed 42 \ 112 | --save_features_together 113 | done -------------------------------------------------------------------------------- /scripts/run_sequence_ablation.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH -N 1 3 | #SBATCH -c 20 4 | #SBATCH --gres=gpu:volta:1 5 | #SBATCH -o log/%j-sequence_ablation.log 6 | 7 | 8 | # set environment variables 9 | export PATH=$SPARSE_PROBING_ROOT:$PATH 10 | 11 | export HF_DATASETS_OFFLINE=1 12 | export TRANSFORMERS_OFFLINE=1 13 | 14 | export RESULTS_DIR=/home/gridsan/groups/maia_mechint/sparse_probing/results 15 | export FEATURE_DATASET_DIR=/home/gridsan/groups/maia_mechint/sparse_probing/feature_datasets 16 | export TRANSFORMERS_CACHE=/home/gridsan/groups/maia_mechint/models 17 | export HF_DATASETS_CACHE=/home/gridsan/groups/maia_mechint/sparse_probing/hf_home 18 | export HF_HOME=/home/gridsan/groups/maia_mechint/sparse_probing/hf_home 19 | 20 | sleep 0.1 # wait for paths to update 21 | 22 | # activate environment and load modules 23 | source $SPARSE_PROBING_ROOT/sparprob/bin/activate 24 | 25 | 26 | python run_ablation.py \ 27 | --feature_dataset natural_lang_id.pyth.512.-1 \ 28 | --model pythia-70m \ 29 | --batch_size 64 \ 30 | --neuron_subset_file monosemantic_language_neurons.csv 31 | 32 | 33 | python run_ablation.py \ 34 | --feature_dataset programming_lang_id.pyth.512.-1 \ 35 | --model 
pythia-1b \ 36 | --batch_size 32 \ 37 | --neuron_subset_file monosemantic_code_neurons.csv 38 | 39 | 40 | python run_ablation.py \ 41 | --feature_dataset distribution_id.pyth.512.-1 \ 42 | --model pythia-6.9b \ 43 | --batch_size 8 \ 44 | --neuron_subset_file monosemantic_distribution_neurons.csv 45 | 46 | 47 | python run_ablation.py \ 48 | --feature_dataset pile_test.pyth.512.-1 \ 49 | --model pythia-70m \ 50 | --batch_size 64 \ 51 | --neuron_subset_file monosemantic_language_neurons.csv 52 | 53 | 54 | python run_ablation.py \ 55 | --feature_dataset pile_test.pyth.512.-1 \ 56 | --model pythia-1b \ 57 | --batch_size 32 \ 58 | --neuron_subset 6,3108 9,7926 10,3855 9,1693 -------------------------------------------------------------------------------- /scripts/run_superposition_experiment.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --mem-per-cpu=4G 3 | #SBATCH -N 1 4 | #SBATCH -c 24 5 | #SBATCH -o log/runtest.log-%j-%a 6 | #SBATCH -a 1-24 7 | 8 | 9 | # set environment variables 10 | export PATH=$SPARSE_PROBING_ROOT:$PATH 11 | 12 | export HF_DATASETS_OFFLINE=1 13 | export TRANSFORMERS_OFFLINE=1 14 | 15 | export RESULTS_DIR=/home/gridsan/groups/maia_mechint/sparse_probing/results 16 | export FEATURE_DATASET_DIR=/home/gridsan/groups/maia_mechint/sparse_probing/feature_datasets 17 | export TRANSFORMERS_CACHE=/home/gridsan/groups/maia_mechint/models 18 | export HF_DATASETS_CACHE=/home/gridsan/groups/maia_mechint/sparse_probing/hf_home 19 | export HF_HOME=/home/gridsan/groups/maia_mechint/sparse_probing/hf_home 20 | 21 | sleep 0.1 # wait for paths to update 22 | 23 | # activate environment and load modules 24 | source $SPARSE_PROBING_ROOT/sparprob/bin/activate 25 | source /etc/profile 26 | module load gurobi/gurobi-1000 27 | 28 | 29 | PYTHIA_MODELS=('pythia-70m' 'pythia-160m' 'pythia-410m' 'pythia-1b' 'pythia-1.4b') 30 | 31 | 32 | for model in "${PYTHIA_MODELS[@]}" 33 | do 34 | python probing_experiment.py \ 35 | --experiment_name bigram_test \ 36 | --experiment_type enumerate_monosemantic optimal_sparse_probing \ 37 | --model "$model" \ 38 | --feature_dataset compound_words.pyth.16.-1 \ 39 | --batch_size 256 40 | done 41 | 42 | 43 | # python probing_experiment.py \ 44 | # --experiment_name superposition_test350 \ 45 | # --experiment_type enumerate_monosemantic optimal_sparse_probing\ 46 | # --model pythia-350m \ 47 | # --feature_dataset programming_lang_id.pyth.512.-1 \ 48 | # --activation_aggregation mean \ 49 | # --gurobi_timeout 300 50 | 51 | 52 | # python probing_experiment.py \ 53 | # --experiment_name superposition_test350 \ 54 | # --experiment_type enumerate_monosemantic optimal_sparse_probing\ 55 | # --model pythia-350m \ 56 | # --feature_dataset natural_lang_id.pyth.512.-1 \ 57 | # --activation_aggregation mean \ 58 | # --gurobi_timeout 300 59 | -------------------------------------------------------------------------------- /scripts/save_weight_statistics.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH -c 20 3 | #SBATCH --gres=gpu:volta:1 4 | #SBATCH -o log/save_weight_stats.log-%j-%a 5 | 6 | 7 | # set environment variables 8 | export PATH=$SPARSE_PROBING_ROOT:$PATH 9 | 10 | export HF_DATASETS_OFFLINE=1 11 | export TRANSFORMERS_OFFLINE=1 12 | 13 | export RESULTS_DIR=/home/gridsan/groups/maia_mechint/sparse_probing/results 14 | export FEATURE_DATASET_DIR=/home/gridsan/groups/maia_mechint/sparse_probing/feature_datasets 15 | export 
TRANSFORMERS_CACHE=/home/gridsan/groups/maia_mechint/models 16 | export HF_DATASETS_CACHE=/home/gridsan/groups/maia_mechint/sparse_probing/hf_home 17 | export HF_HOME=/home/gridsan/groups/maia_mechint/sparse_probing/hf_home 18 | 19 | sleep 0.1 # wait for paths to update 20 | 21 | # activate environment and load modules 22 | source $SPARSE_PROBING_ROOT/sparprob/bin/activate 23 | 24 | PYTHIA_MODELS=('pythia-70m' 'pythia-160m' 'pythia-410m' 'pythia-1b' 'pythia-1.4b' 'pythia-2.8b' 'pythia-6.9b' 'pythia-12b') 25 | 26 | for model in "${PYTHIA_MODELS[@]}" 27 | do 28 | python -u save_weight_statistics.py --model "$model" 29 | done 30 | 31 | GPT_MODELS=('gpt2-small' 'gpt2-medium' 'gpt2-large' 'gpt2-xl') 32 | for model in "${GPT_MODELS[@]}" 33 | do 34 | python -u save_weight_statistics.py --model "$model" 35 | done -------------------------------------------------------------------------------- /scripts/wikidata/activations.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH -c 20 3 | #SBATCH --gres=gpu:volta:1 4 | #SBATCH -o log/activations.log-%j 5 | 6 | # set environment variables 7 | export PATH=$SPARSE_PROBING_ROOT:$PATH 8 | 9 | export HF_DATASETS_OFFLINE=1 10 | export TRANSFORMERS_OFFLINE=1 11 | 12 | export RESULTS_DIR=/home/gridsan/groups/maia_mechint/sparse_probing/results 13 | export FEATURE_DATASET_DIR=/home/gridsan/groups/maia_mechint/sparse_probing/feature_datasets 14 | export TRANSFORMERS_CACHE=/home/gridsan/groups/maia_mechint/models 15 | export HF_DATASETS_CACHE=/home/gridsan/groups/maia_mechint/sparse_probing/hf_home 16 | export HF_HOME=/home/gridsan/groups/maia_mechint/sparse_probing/hf_home 17 | 18 | sleep 0.1 # wait for paths to update 19 | 20 | # activate environment and load modules 21 | source $SPARSE_PROBING_ROOT/sparprob/bin/activate 22 | source /etc/profile 23 | module load gurobi/gurobi-1000 24 | 25 | #PYTHIA_MODELS=('pythia-70m' 'pythia-160m' 'pythia-410m' 'pythia-1b' 'pythia-1.4b' 'pythia-2.8b' 'pythia-6.9b') 26 | #n_layers: 6 12 24 16 24 32 32 27 | #MODELS=('pythia-70m') 28 | #MODELS=('pythia-160m' 'pythia-410m' 'pythia-1.4b' 'pythia-2.8b') 29 | #MODELS=('pythia-70m' 'pythia-1b' 'pythia-6.9b') 30 | MODELS=('pythia-6.9b') 31 | 32 | FEATURE=occupation # sex_or_gender is_alive occupation political_party occupation_athlete 33 | N_SEQ=6000 34 | 35 | PREFIX=wikidata_full 36 | 37 | for model in "${MODELS[@]}" 38 | do 39 | python get_activations.py \ 40 | --experiment_type activation_probe_dataset \ 41 | --feature_dataset "${PREFIX}_${FEATURE}.pyth.128.${N_SEQ}" \ 42 | --model $model \ 43 | --activation_aggregation max 44 | done 45 | -------------------------------------------------------------------------------- /scripts/wikidata/dataset.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH -c 4 3 | #SBATCH -o log/dataset.log-%j-%a 4 | 5 | # set environment variables 6 | export PATH=$SPARSE_PROBING_ROOT:$PATH 7 | 8 | export HF_DATASETS_OFFLINE=1 9 | export TRANSFORMERS_OFFLINE=1 10 | 11 | export RESULTS_DIR=/home/gridsan/groups/maia_mechint/sparse_probing/results 12 | export FEATURE_DATASET_DIR=/home/gridsan/groups/maia_mechint/sparse_probing/feature_datasets 13 | export TRANSFORMERS_CACHE=/home/gridsan/groups/maia_mechint/models 14 | export HF_DATASETS_CACHE=/home/gridsan/groups/maia_mechint/sparse_probing/hf_home 15 | export HF_HOME=/home/gridsan/groups/maia_mechint/sparse_probing/hf_home 16 | 17 | sleep 0.1 # wait for paths to update 18 | 19 | # activate 
environment and load modules 20 | source $SPARSE_PROBING_ROOT/sparprob/bin/activate 21 | source /etc/profile 22 | 23 | PROPERTY=political_party # sex_or_gender is_alive occupation political_party occupation_athlete 24 | MAX_PER_CLASS=500 25 | N_SEQS=1000 26 | TABLE_SIZE=-1 27 | NUM_PROC=1 # probably should keep this at 1 unless max_name_repeats doesn't matter 28 | MAX_NAME_REPEATS=1 # note that this is per process 29 | MIN_PILE_REPEATS=2 30 | 31 | PREFIX=wikidata_full 32 | 33 | # create the table 34 | python make_feature_datasets.py \ 35 | --feature_collection wikidata \ 36 | --model pythia-70m \ 37 | --seq_len 128 \ 38 | --n_seqs $N_SEQS \ 39 | --num_proc $NUM_PROC \ 40 | --wikidata_table_path "/home/gridsan/groups/maia_mechint/datasets/wikidata/wikidata_pile_test_${TABLE_SIZE}.csv" \ 41 | --dataset_name "${PREFIX}_${PROPERTY}" \ 42 | --wikidata_property $PROPERTY \ 43 | --max_per_class $MAX_PER_CLASS \ 44 | --max_name_repeats $MAX_NAME_REPEATS \ 45 | --min_pile_repeats $MIN_PILE_REPEATS 46 | -------------------------------------------------------------------------------- /scripts/wikidata/neurons.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH -c 20 3 | #SBATCH --gres=gpu:volta:1 4 | #SBATCH -o log/neurons.log-%j 5 | 6 | # set environment variables 7 | export PATH=$SPARSE_PROBING_ROOT:$PATH 8 | 9 | export HF_DATASETS_OFFLINE=1 10 | export TRANSFORMERS_OFFLINE=1 11 | 12 | export RESULTS_DIR=/home/gridsan/groups/maia_mechint/sparse_probing/results 13 | export FEATURE_DATASET_DIR=/home/gridsan/groups/maia_mechint/sparse_probing/feature_datasets 14 | export TRANSFORMERS_CACHE=/home/gridsan/groups/maia_mechint/models 15 | export HF_DATASETS_CACHE=/home/gridsan/groups/maia_mechint/sparse_probing/hf_home 16 | export HF_HOME=/home/gridsan/groups/maia_mechint/sparse_probing/hf_home 17 | export INTERPRETABLE_NEURONS_DIR=$SPARSE_PROBING_ROOT/interpretable_neurons 18 | 19 | sleep 0.1 # wait for paths to update 20 | 21 | # activate environment and load modules 22 | source $SPARSE_PROBING_ROOT/sparprob/bin/activate 23 | source /etc/profile 24 | module load gurobi/gurobi-1000 25 | 26 | #PYTHIA_MODELS=('pythia-70m' 'pythia-160m' 'pythia-410m' 'pythia-1b' 'pythia-1.4b' 'pythia-2.8b' 'pythia-6.9b') 27 | #n_layers: 6 12 24 16 24 32 32 28 | #MODELS=('pythia-70m') 29 | #MODELS=('pythia-70m' 'pythia-160m' 'pythia-410m' 'pythia-1b') 30 | #MODELS=('pythia-1.4b' 'pythia-2.8b' 'pythia-6.9b') 31 | 32 | MODEL=pythia-6.9b 33 | FEATURE=occupation # sex_or_gender is_alive occupation political_party occupation_athlete 34 | N_SEQ=6000 35 | 36 | PREFIX=wikidata_full 37 | SUBSET_FILE=wikidata.csv 38 | 39 | python get_activations.py \ 40 | --experiment_type activation_subset \ 41 | --feature_dataset "${PREFIX}_${FEATURE}.pyth.128.${N_SEQ}" \ 42 | --model $MODEL \ 43 | --neuron_subset_file $SUBSET_FILE \ 44 | --auto_restrict_neuron_subset_file \ 45 | --output_precision 32 46 | #--skip_computing_token_summary_df 47 | -------------------------------------------------------------------------------- /scripts/wikidata/probe.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH -c 20 3 | #SBATCH -o log/probe.log-%j-%a 4 | #SBATCH -a 1-8 5 | 6 | # set environment variables 7 | export PATH=$SPARSE_PROBING_ROOT:$PATH 8 | 9 | export HF_DATASETS_OFFLINE=1 10 | export TRANSFORMERS_OFFLINE=1 11 | 12 | export RESULTS_DIR=/home/gridsan/groups/maia_mechint/sparse_probing/results 13 | export 
FEATURE_DATASET_DIR=/home/gridsan/groups/maia_mechint/sparse_probing/feature_datasets 14 | export TRANSFORMERS_CACHE=/home/gridsan/groups/maia_mechint/models 15 | export HF_DATASETS_CACHE=/home/gridsan/groups/maia_mechint/sparse_probing/hf_home 16 | export HF_HOME=/home/gridsan/groups/maia_mechint/sparse_probing/hf_home 17 | 18 | sleep 0.1 # wait for paths to update 19 | 20 | # activate environment and load modules 21 | source $SPARSE_PROBING_ROOT/sparprob/bin/activate 22 | source /etc/profile 23 | module load gurobi/gurobi-1000 24 | 25 | #PYTHIA_MODELS=('pythia-70m' 'pythia-160m' 'pythia-410m' 'pythia-1b' 'pythia-1.4b' 'pythia-2.8b' 'pythia-6.9b') 26 | #n_layers: 6 12 24 16 24 32 32 27 | 28 | MODELS=('pythia-6.9b') 29 | 30 | FEATURE=occupation 31 | N_SEQ=6000 32 | 33 | PREFIX=wikidata_full 34 | 35 | for model in "${MODELS[@]}" 36 | do 37 | python probing_experiment.py \ 38 | --experiment_name wikidata/$FEATURE \ 39 | --model "$model" \ 40 | --feature_dataset "${PREFIX}_${FEATURE}.pyth.128.${N_SEQ}" \ 41 | --experiment_type telescopic_sparsity_sweep enumerate_monosemantic \ 42 | --activation_aggregation max \ 43 | --normalize_activations 44 | done 45 | -------------------------------------------------------------------------------- /scripts/wikidata/table.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH -c 8 3 | #SBATCH -o log/table.log-%j-%a 4 | 5 | # set environment variables 6 | export PATH=$SPARSE_PROBING_ROOT:$PATH 7 | 8 | export HF_DATASETS_OFFLINE=1 9 | export TRANSFORMERS_OFFLINE=1 10 | 11 | export RESULTS_DIR=/home/gridsan/groups/maia_mechint/sparse_probing/results 12 | export FEATURE_DATASET_DIR=/home/gridsan/groups/maia_mechint/sparse_probing/feature_datasets 13 | export TRANSFORMERS_CACHE=/home/gridsan/groups/maia_mechint/models 14 | export HF_DATASETS_CACHE=/home/gridsan/groups/maia_mechint/sparse_probing/hf_home 15 | export HF_HOME=/home/gridsan/groups/maia_mechint/sparse_probing/hf_home 16 | 17 | sleep 0.1 # wait for paths to update 18 | 19 | # activate environment and load modules 20 | source $SPARSE_PROBING_ROOT/sparprob/bin/activate 21 | source /etc/profile 22 | 23 | N_LINES=-1 24 | #N_LINES=250 25 | 26 | # create the table 27 | python probing_datasets/wikidata.py \ 28 | --n_lines $N_LINES \ 29 | --raw_path /home/gridsan/groups/maia_mechint/datasets/wikidata/raw/latest-all.json.bz2 \ 30 | --dataset_path /home/gridsan/groups/maia_mechint/datasets/pile-test.hf \ 31 | --output_path "/home/gridsan/groups/maia_mechint/datasets/wikidata/wikidata_pile_test_${N_LINES}.csv" 32 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import random 3 | import numpy as np 4 | import torch 5 | import argparse 6 | import time 7 | 8 | 9 | def seed_all(seed): 10 | random.seed(seed) 11 | np.random.seed(seed) 12 | torch.manual_seed(seed) 13 | torch.cuda.manual_seed(seed) 14 | 15 | 16 | def timestamp(): 17 | return datetime.datetime.now().strftime("%Y:%m:%d:%H:%M:%S") 18 | 19 | 20 | def default_argument_parser(): 21 | parser = argparse.ArgumentParser( 22 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 23 | # experiment params 24 | parser.add_argument( 25 | '--experiment_name', default=str(int(time.time()) // 10), 26 | help='Name of experiment to save') 27 | parser.add_argument( 28 | '--experiment_type', nargs='+', required=True, 29 | help='The inner loop function(s) to run for 
the experiment') 30 | parser.add_argument( 31 | '--model', default='pythia-70m', 32 | help='Name of model from TransformerLens') 33 | parser.add_argument( 34 | '--feature_dataset', 35 | help='Name of cached feature dataset') 36 | parser.add_argument( 37 | '--probe_location', default='mlp.hook_post', 38 | help='Model component to probe') 39 | parser.add_argument( 40 | '--activation_aggregation', default=None, 41 | help='Average activations across all tokens in a sequence') 42 | parser.add_argument( 43 | '--seed', default=1, type=int, 44 | help='Random seed for experiment') 45 | parser.add_argument( 46 | '--probe_next_token_feature', action='store_true', 47 | help='Probe the token before the probe_index to predict property of the probe_index') 48 | return parser 49 | 50 | 51 | MODEL_N_LAYERS = { 52 | 'pythia-70m': 6, 53 | 'pythia-160m': 12, 54 | 'pythia-410m': 24, 55 | 'pythia-1b': 16, 56 | 'pythia-1.4b': 24, 57 | 'pythia-2.8b': 32, 58 | 'pythia-6.9b': 32 59 | } 60 | --------------------------------------------------------------------------------
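For orientation, a minimal, hypothetical sketch of how the helpers in utils.py fit together with the shell scripts above: a driver in the style of probing_experiment.py would build on default_argument_parser (registering the extra flags the scripts pass, such as --normalize_activations and --save_features_together), seed all RNGs with seed_all, and use MODEL_N_LAYERS to iterate over the probed model's layers. The loop body is elided; the actual experiment dispatch lives in experiments/ and is not reproduced here.

from utils import default_argument_parser, seed_all, MODEL_N_LAYERS

def main():
    parser = default_argument_parser()
    # flags used by the scripts above but registered by the real driver, not by utils.py
    parser.add_argument('--normalize_activations', action='store_true')
    parser.add_argument('--save_features_together', action='store_true')
    args = parser.parse_args()

    seed_all(args.seed)                    # make random/numpy/torch deterministic
    n_layers = MODEL_N_LAYERS[args.model]  # e.g. 6 for pythia-70m, 16 for pythia-1b

    for layer in range(n_layers):
        # load cached activations for (args.feature_dataset, layer, args.probe_location)
        # and run each inner-loop routine named in args.experiment_type
        ...

if __name__ == '__main__':
    main()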